diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index 9e866adc3..334d4cffb 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -236,7 +236,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { buf, _ := dw.Msg.Pack() mt, _ := response.Typify(dw.Msg, time.Now().UTC()) - age := dnsutil.MinimalTTL(dw.Msg, mt) + age := dnsutil.MinimalTTLWithMaximum(dw.Msg, mt, dnsutil.MaximumDefaultTTL) w.Header().Set("Content-Type", doh.MimeType) w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", uint32(age.Seconds()))) diff --git a/core/dnsserver/server_https3.go b/core/dnsserver/server_https3.go index 340d83f4e..b43fc7cd8 100644 --- a/core/dnsserver/server_https3.go +++ b/core/dnsserver/server_https3.go @@ -216,7 +216,7 @@ func (s *ServerHTTPS3) ServeHTTP(w http.ResponseWriter, r *http.Request) { buf, _ := dw.Msg.Pack() mt, _ := response.Typify(dw.Msg, time.Now().UTC()) - age := dnsutil.MinimalTTL(dw.Msg, mt) + age := dnsutil.MinimalTTLWithMaximum(dw.Msg, mt, dnsutil.MaximumDefaultTTL) w.Header().Set("Content-Type", doh.MimeType) w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", uint32(age.Seconds()))) diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index 15312a1a6..fba713576 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -221,14 +221,15 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { // key returns empty string for anything we don't want to cache. hasKey, key := key(w.state.Name(), res, mt, w.do, w.cd) - msgTTL := dnsutil.MinimalTTL(res, mt) var duration time.Duration switch mt { case response.NameError, response.NoData: + msgTTL := dnsutil.MinimalTTLWithMaximum(res, mt, w.nttl) duration = computeTTL(msgTTL, w.minnttl, w.nttl) case response.ServerError: duration = w.failttl default: + msgTTL := dnsutil.MinimalTTLWithMaximum(res, mt, w.pttl) duration = computeTTL(msgTTL, w.minpttl, w.pttl) } diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go index 811884495..0a11f9737 100644 --- a/plugin/cache/cache_test.go +++ b/plugin/cache/cache_test.go @@ -362,6 +362,27 @@ func TestCacheZeroTTL(t *testing.T) { } } +func TestCacheHonorsConfiguredPositiveMaxTTLAboveDefault(t *testing.T) { + c := New() + c.pttl = 2 * time.Hour + c.minpttl = 0 + c.Next = ttlBackend(24 * 60 * 60) + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, req) + + if rec.Msg == nil || len(rec.Msg.Answer) == 0 { + t.Fatalf("expected answer, got %+v", rec.Msg) + } + + if got, want := rec.Msg.Answer[0].Header().Ttl, uint32(7200); got != want { + t.Fatalf("expected TTL %d, got %d", want, got) + } +} + func TestCacheServfailTTL0(t *testing.T) { c := New() c.minpttl = minTTL diff --git a/plugin/pkg/dnsutil/ttl.go b/plugin/pkg/dnsutil/ttl.go index 7ac2f987a..7074898ac 100644 --- a/plugin/pkg/dnsutil/ttl.go +++ b/plugin/pkg/dnsutil/ttl.go @@ -10,6 +10,12 @@ import ( // MinimalTTL scans the message returns the lowest TTL found taking into the response.Type of the message. func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration { + return MinimalTTLWithMaximum(m, mt, MaximumDefaultTTL) +} + +// MinimalTTLWithMaximum scans the DNS message and returns the lowest TTL found, +// constrained by maximumTTL and the response type. +func MinimalTTLWithMaximum(m *dns.Msg, mt response.Type, maximumTTL time.Duration) time.Duration { if mt != response.NoError && mt != response.NameError && mt != response.NoData { return MinimalDefaultTTL } @@ -20,7 +26,7 @@ func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration { return MinimalDefaultTTL } - minTTL := MaximumDefaulTTL + minTTL := maximumTTL for _, r := range m.Answer { if r.Header().Ttl < uint32(minTTL.Seconds()) { minTTL = time.Duration(r.Header().Ttl) * time.Second