diff --git a/core/dnsserver/server_https3.go b/core/dnsserver/server_https3.go index a297a2c0a..597f0e529 100644 --- a/core/dnsserver/server_https3.go +++ b/core/dnsserver/server_https3.go @@ -18,6 +18,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/reuseport" "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" ) @@ -172,7 +173,7 @@ func (s *ServerHTTPS3) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - msg, err := doh.RequestToMsg(r) + msg, raw, err := doh.RequestToMsgWire(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) s.countResponse(http.StatusBadRequest) @@ -188,6 +189,16 @@ func (s *ServerHTTPS3) ServeHTTP(w http.ResponseWriter, r *http.Request) { request: r, } + if tsig := msg.IsTsig(); tsig != nil { + if s.tsigSecret == nil { + dw.tsigStatus = dns.ErrSecret + } else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + dw.tsigStatus = dns.ErrSecret + } else { + dw.tsigStatus = dns.TsigVerify(raw, secret, "", false) + } + } + ctx := context.WithValue(r.Context(), Key{}, s.Server) ctx = context.WithValue(ctx, LoopKey{}, 0) ctx = context.WithValue(ctx, HTTPRequestKey{}, r) diff --git a/core/dnsserver/server_https3_test.go b/core/dnsserver/server_https3_test.go index bd460c889..faf879a5c 100644 --- a/core/dnsserver/server_https3_test.go +++ b/core/dnsserver/server_https3_test.go @@ -3,13 +3,25 @@ package dnsserver import ( "bytes" "crypto/tls" + "errors" + "io" "net/http" "net/http/httptest" + "strings" "testing" + "time" + + "github.com/coredns/coredns/plugin" "github.com/miekg/dns" ) +const ( + testTSIGKeyNameHTTPS3 = "tsig-key." + testTSIGSecretHTTPS3 = "MTIzNA==" + testTSIGWrongSecretHTTPS3 = "NTY3OA==" +) + func testServerHTTPS3(t *testing.T, path string, validator func(*http.Request) bool) *http.Response { t.Helper() c := Config{ @@ -139,3 +151,178 @@ func TestNewServerHTTPS3ZeroLimits(t *testing.T) { t.Errorf("Expected quicConfig.MaxIncomingStreams = 0 (QUIC default), got %d", server.quicConfig.MaxIncomingStreams) } } + +func testConfigWithTSIGCheckPluginHTTPS3(t *testing.T, check func(*testing.T, error)) *Config { + t.Helper() + + c := &Config{ + Zone: "example.com.", + Transport: "https3", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + TsigSecret: map[string]string{ + testTSIGKeyNameHTTPS3: testTSIGSecretHTTPS3, + }, + } + c.AddPlugin(func(_next plugin.Handler) plugin.Handler { + return tsigStatusCheckPlugin{t: t, check: check} + }) + return c +} + +func testServerHTTPS3Msg(t *testing.T, cfg *Config, req *dns.Msg) *dns.Msg { + t.Helper() + + s, err := NewServerHTTPS3("127.0.0.1:443", []*Config{cfg}) + if err != nil { + t.Fatal("could not create HTTPS3 server:", err) + } + + buf, err := req.Pack() + if err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodPost, "/dns-query", bytes.NewReader(buf)) + r.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + res := w.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Fatalf("unexpected HTTP status: got %d", res.StatusCode) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + m := new(dns.Msg) + if err := m.Unpack(body); err != nil { + t.Fatal(err) + } + return m +} + +func testServerHTTPS3Raw(t *testing.T, cfg *Config, buf []byte) *dns.Msg { + t.Helper() + + s, err := NewServerHTTPS3("127.0.0.1:443", []*Config{cfg}) + if err != nil { + t.Fatal("could not create HTTPS3 server:", err) + } + + r := httptest.NewRequest(http.MethodPost, "/dns-query", bytes.NewReader(buf)) + r.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + res := w.Result() + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + t.Fatalf("unexpected HTTP status: got %d", res.StatusCode) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + m := new(dns.Msg) + if err := m.Unpack(body); err != nil { + t.Fatal(err) + } + return m +} + +func forgedTSIGMsgHTTPS3() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + + m.Extra = append(m.Extra, &dns.TSIG{ + Hdr: dns.RR_Header{ + Name: "bogus-key.", + Rrtype: dns.TypeTSIG, + Class: dns.ClassANY, + Ttl: 0, + }, + Algorithm: dns.HmacSHA256, + TimeSigned: uint64(time.Now().Unix()), + Fudge: 300, + MACSize: 32, + MAC: strings.Repeat("00", 32), + OrigId: m.Id, + Error: dns.RcodeSuccess, + }) + return m +} + +func TestServeHTTP3RejectsUnsignedTSIGRequiredRequest(t *testing.T) { + cfg := testConfigWithTSIGCheckPluginHTTPS3(t, func(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("expected nil TsigStatus for unsigned request, got %v", err) + } + }) + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + resp := testServerHTTPS3Msg(t, cfg, m) + + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("expected NOERROR response from plugin, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestServeHTTP3RejectsTSIGWithUnknownKey(t *testing.T) { + cfg := testConfigWithTSIGCheckPluginHTTPS3(t, func(t *testing.T, err error) { + t.Helper() + if !errors.Is(err, dns.ErrSecret) { + t.Fatalf("expected dns.ErrSecret for unknown TSIG key, got %v", err) + } + }) + + resp := testServerHTTPS3Msg(t, cfg, forgedTSIGMsgHTTPS3()) + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("expected NOERROR response from plugin, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestServeHTTP3RejectsTSIGWithBadMAC(t *testing.T) { + cfg := testConfigWithTSIGCheckPluginHTTPS3(t, func(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatal("expected non-nil TsigStatus for bad TSIG MAC") + } + }) + + buf := mustPackSignedTSIGQuery(t, testTSIGKeyNameHTTPS3, testTSIGWrongSecretHTTPS3, time.Now().Unix()) + resp := testServerHTTPS3Raw(t, cfg, buf) + + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("expected NOERROR response from plugin, got %s", dns.RcodeToString[resp.Rcode]) + } +} + +func TestServeHTTP3AcceptsValidTSIG(t *testing.T) { + cfg := testConfigWithTSIGCheckPluginHTTPS3(t, func(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("expected nil TsigStatus for valid TSIG, got %v", err) + } + }) + + buf := mustPackSignedTSIGQuery(t, testTSIGKeyNameHTTPS3, testTSIGSecretHTTPS3, time.Now().Unix()) + resp := testServerHTTPS3Raw(t, cfg, buf) + + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("expected NOERROR response from plugin, got %s", dns.RcodeToString[resp.Rcode]) + } +}