feat(cache): add optional verify timeout to serve_stale (#8070)

This commit is contained in:
Syed Azeez
2026-05-06 13:02:28 +05:30
committed by GitHub
parent 145029c847
commit b2cb44b966
6 changed files with 226 additions and 23 deletions

View File

@@ -501,6 +501,103 @@ func TestServeFromStaleCacheFetchVerify(t *testing.T) {
}
}
func TestServeFromStaleCacheFetchVerifyTimeout(t *testing.T) {
// Verify that when verifyStaleTimeout is set and the upstream is slow,
// the client gets the stale entry within ~timeout, while the in-flight
// verify continues in the background and refreshes the cache.
c := New()
c.staleUpTo = 1 * time.Hour
c.verifyStale = true
c.verifyStaleTimeout = 50 * time.Millisecond
c.Next = ttlBackend(120)
req := new(dns.Msg)
req.SetQuestion("cached.org.", dns.TypeA)
ctx := context.TODO()
// Prime the cache with a 120s TTL entry.
rec := dnstest.NewRecorder(&test.ResponseWriter{})
c.ServeDNS(ctx, rec, req)
if c.pcache.Len() != 1 {
t.Fatalf("Msg with > 0 TTL should have been cached")
}
// Move forward past the TTL so the entry is stale.
c.now = func() time.Time { return time.Now().Add(3 * time.Minute) }
// Swap in a slow backend that takes longer than the verify timeout.
bgDone := make(chan struct{})
c.Next = slowTTLBackend(60, 200*time.Millisecond, bgDone)
rec = dnstest.NewRecorder(&test.ResponseWriter{})
start := time.Now()
ret, err := c.ServeDNS(ctx, rec, req.Copy())
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ret != dns.RcodeSuccess {
t.Fatalf("expected RcodeSuccess, got %d", ret)
}
if elapsed > 150*time.Millisecond {
t.Errorf("expected response within ~timeout (50ms); took %v", elapsed)
}
if rec.Msg == nil || len(rec.Msg.Answer) == 0 {
t.Fatalf("expected an answer, got %+v", rec.Msg)
}
// Stale serve sets TTL to 0.
if got := rec.Msg.Answer[0].Header().Ttl; got != 0 {
t.Errorf("expected stale TTL=0, got %d", got)
}
// Wait for the background verify to complete.
select {
case <-bgDone:
case <-time.After(2 * time.Second):
t.Fatalf("background verify never completed")
}
}
func TestServeFromStaleCacheFetchVerifyTimeoutFastUpstream(t *testing.T) {
// When the upstream answers within the verify timeout, the client should
// receive the freshly verified response (not a stale one).
c := New()
c.staleUpTo = 1 * time.Hour
c.verifyStale = true
c.verifyStaleTimeout = 500 * time.Millisecond
c.Next = ttlBackend(120)
req := new(dns.Msg)
req.SetQuestion("cached.org.", dns.TypeA)
ctx := context.TODO()
rec := dnstest.NewRecorder(&test.ResponseWriter{})
c.ServeDNS(ctx, rec, req)
if c.pcache.Len() != 1 {
t.Fatalf("Msg with > 0 TTL should have been cached")
}
c.now = func() time.Time { return time.Now().Add(3 * time.Minute) }
// Fast upstream returning fresh TTL=200.
c.Next = ttlBackend(200)
rec = dnstest.NewRecorder(&test.ResponseWriter{})
ret, err := c.ServeDNS(ctx, rec, req.Copy())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ret != dns.RcodeSuccess {
t.Fatalf("expected RcodeSuccess, got %d", ret)
}
if rec.Msg == nil || len(rec.Msg.Answer) == 0 {
t.Fatalf("expected an answer, got %+v", rec.Msg)
}
if got := rec.Msg.Answer[0].Header().Ttl; got != 200 {
t.Errorf("expected fresh TTL=200, got %d", got)
}
}
func TestNegativeStaleMaskingPositiveCache(t *testing.T) {
c := New()
c.staleUpTo = time.Minute * 10
@@ -672,6 +769,28 @@ func servFailBackend(ttl int) plugin.Handler {
})
}
// slowTTLBackend wraps ttlBackend with a fixed delay to simulate a slow upstream.
// done is closed once the response is written so callers can synchronise with the
// background goroutine.
func slowTTLBackend(ttl int, delay time.Duration, done chan<- struct{}) plugin.Handler {
return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
select {
case <-time.After(delay):
case <-ctx.Done():
return dns.RcodeServerFailure, ctx.Err()
}
m := new(dns.Msg)
m.SetReply(r)
m.Response, m.RecursionAvailable = true, true
m.Answer = []dns.RR{test.A(fmt.Sprintf("example.org. %d IN A 127.0.0.53", ttl))}
w.WriteMsg(m)
if done != nil {
close(done)
}
return dns.RcodeSuccess, nil
})
}
func TestComputeTTL(t *testing.T) {
tests := []struct {
msgTTL time.Duration