diff --git a/plugin/metrics/handler.go b/plugin/metrics/handler.go index fb350a2f5..4ac0e0ecc 100644 --- a/plugin/metrics/handler.go +++ b/plugin/metrics/handler.go @@ -2,7 +2,6 @@ package metrics import ( "context" - "path/filepath" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/metrics/vars" @@ -36,9 +35,9 @@ func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg // see https://github.com/coredns/coredns/blob/master/core/dnsserver/server.go#L318 rc = status } - plugin := m.authoritativePlugin(rw.Caller) // Pass the original request size to vars.Report - vars.Report(WithServer(ctx), state, zone, WithView(ctx), rcode.ToString(rc), plugin, + // rw.Plugin is set automatically by the plugin chain via the PluginTracker interface + vars.Report(WithServer(ctx), state, zone, WithView(ctx), rcode.ToString(rc), rw.Plugin, rw.Len, rw.Start, vars.WithOriginalReqSize(originalSize)) return status, err @@ -46,17 +45,3 @@ func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg // Name implements the Handler interface. func (m *Metrics) Name() string { return "prometheus" } - -// authoritativePlugin returns which of made the write, if none is found the empty string is returned. -func (m *Metrics) authoritativePlugin(caller [3]string) string { - // a b and c contain the full path of the caller, the plugin name 2nd last elements - // .../coredns/plugin/whoami/whoami.go --> whoami - // this is likely FS specific, so use filepath. - for _, c := range caller { - plug := filepath.Base(filepath.Dir(c)) - if _, ok := m.plugins[plug]; ok { - return plug - } - } - return "" -} diff --git a/plugin/metrics/recorder.go b/plugin/metrics/recorder.go index d4d42ba5c..c11ceb8ec 100644 --- a/plugin/metrics/recorder.go +++ b/plugin/metrics/recorder.go @@ -1,8 +1,6 @@ package metrics import ( - "runtime" - "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/miekg/dns" @@ -11,8 +9,9 @@ import ( // Recorder is a dnstest.Recorder specific to the metrics plugin. type Recorder struct { *dnstest.Recorder - // CallerN holds the string return value of the call to runtime.Caller(N+1) - Caller [3]string + // Plugin holds the name of the plugin that wrote the response. + // This is set automatically by the plugin chain via the PluginTracker interface. + Plugin string } // NewRecorder makes and returns a new Recorder. @@ -21,8 +20,15 @@ func NewRecorder(w dns.ResponseWriter) *Recorder { return &Recorder{Recorder: dn // WriteMsg records the status code and calls the // underlying ResponseWriter's WriteMsg method. func (r *Recorder) WriteMsg(res *dns.Msg) error { - _, r.Caller[0], _, _ = runtime.Caller(1) - _, r.Caller[1], _, _ = runtime.Caller(2) - _, r.Caller[2], _, _ = runtime.Caller(3) return r.Recorder.WriteMsg(res) } + +// SetPlugin implements the plugin.PluginTracker interface. +func (r *Recorder) SetPlugin(name string) { + r.Plugin = name +} + +// GetPlugin implements the plugin.PluginTracker interface. +func (r *Recorder) GetPlugin() string { + return r.Plugin +} diff --git a/plugin/metrics/recorder_test.go b/plugin/metrics/recorder_test.go index fd8c5fc8d..30d95b109 100644 --- a/plugin/metrics/recorder_test.go +++ b/plugin/metrics/recorder_test.go @@ -23,6 +23,34 @@ func (r *inmemoryWriter) Write(buf []byte) (int, error) { return r.ResponseWriter.Write(buf) } +func TestRecorder_PluginTracker(t *testing.T) { + tw := inmemoryWriter{ResponseWriter: test.ResponseWriter{}} + rec := NewRecorder(&tw) + + // Initially Plugin should be empty + if rec.Plugin != "" { + t.Errorf("Expected empty Plugin, got %q", rec.Plugin) + } + if rec.GetPlugin() != "" { + t.Errorf("Expected GetPlugin() to return empty string, got %q", rec.GetPlugin()) + } + + // SetPlugin should set the plugin name + rec.SetPlugin("whoami") + if rec.Plugin != "whoami" { + t.Errorf("Expected Plugin to be 'whoami', got %q", rec.Plugin) + } + if rec.GetPlugin() != "whoami" { + t.Errorf("Expected GetPlugin() to return 'whoami', got %q", rec.GetPlugin()) + } + + // SetPlugin should overwrite previous value + rec.SetPlugin("cache") + if rec.Plugin != "cache" { + t.Errorf("Expected Plugin to be 'cache', got %q", rec.Plugin) + } +} + func TestRecorder_WriteMsg(t *testing.T) { successResp := dns.Msg{} successResp.Answer = []dns.RR{ diff --git a/plugin/plugin.go b/plugin/plugin.go index 43c1e6547..2cb55e64d 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "github.com/miekg/dns" ot "github.com/opentracing/opentracing-go" @@ -77,12 +78,60 @@ func NextOrFailure(name string, next Handler, ctx context.Context, w dns.Respons defer child.Finish() ctx = ot.ContextWithSpan(ctx, child) } - return next.ServeDNS(ctx, w, r) + // Wrap the ResponseWriter to track which plugin writes the response + pw := &pluginWriter{ResponseWriter: w, plugin: next.Name()} + return next.ServeDNS(ctx, pw, r) } return dns.RcodeServerFailure, Error(name, errors.New("no next plugin found")) } +// PluginTracker is an interface for ResponseWriters that track which plugin wrote the response. +type PluginTracker interface { + SetPlugin(name string) + GetPlugin() string +} + +// pluginWriter wraps a dns.ResponseWriter to track which plugin writes the response. +type pluginWriter struct { + dns.ResponseWriter + plugin string +} + +// WriteMsg implements dns.ResponseWriter and tracks the plugin that wrote the response. +func (pw *pluginWriter) WriteMsg(m *dns.Msg) error { + if tracker, ok := pw.ResponseWriter.(PluginTracker); ok { + tracker.SetPlugin(pw.plugin) + } + return pw.ResponseWriter.WriteMsg(m) +} + +// Write implements dns.ResponseWriter. +func (pw *pluginWriter) Write(b []byte) (int, error) { + if tracker, ok := pw.ResponseWriter.(PluginTracker); ok { + tracker.SetPlugin(pw.plugin) + } + return pw.ResponseWriter.Write(b) +} + +// LocalAddr implements dns.ResponseWriter. +func (pw *pluginWriter) LocalAddr() net.Addr { return pw.ResponseWriter.LocalAddr() } + +// RemoteAddr implements dns.ResponseWriter. +func (pw *pluginWriter) RemoteAddr() net.Addr { return pw.ResponseWriter.RemoteAddr() } + +// Close implements dns.ResponseWriter. +func (pw *pluginWriter) Close() error { return pw.ResponseWriter.Close() } + +// TsigStatus implements dns.ResponseWriter. +func (pw *pluginWriter) TsigStatus() error { return pw.ResponseWriter.TsigStatus() } + +// TsigTimersOnly implements dns.ResponseWriter. +func (pw *pluginWriter) TsigTimersOnly(b bool) { pw.ResponseWriter.TsigTimersOnly(b) } + +// Hijack implements dns.ResponseWriter. +func (pw *pluginWriter) Hijack() { pw.ResponseWriter.Hijack() } + // ClientWrite returns true if the response has been written to the client. // Each plugin to adhere to this protocol. func ClientWrite(rcode int) bool { diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 000000000..2ebce5477 --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,170 @@ +package plugin + +import ( + "context" + "net" + "testing" + + "github.com/miekg/dns" +) + +// mockResponseWriter implements dns.ResponseWriter for testing +type mockResponseWriter struct { + msg *dns.Msg +} + +func (m *mockResponseWriter) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 53} +} +func (m *mockResponseWriter) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40212} +} +func (m *mockResponseWriter) WriteMsg(msg *dns.Msg) error { m.msg = msg; return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} + +// mockPluginTracker implements PluginTracker for testing +type mockPluginTracker struct { + mockResponseWriter + plugin string +} + +func (m *mockPluginTracker) SetPlugin(name string) { m.plugin = name } +func (m *mockPluginTracker) GetPlugin() string { return m.plugin } + +// mockHandler implements Handler for testing +type mockHandler struct { + name string + writeMsg bool + returnErr error +} + +func (m *mockHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if m.writeMsg { + resp := new(dns.Msg) + resp.SetReply(r) + w.WriteMsg(resp) + } + if m.returnErr != nil { + return dns.RcodeServerFailure, m.returnErr + } + return dns.RcodeSuccess, nil +} + +func (m *mockHandler) Name() string { return m.name } + +func TestPluginWriter_WriteMsg_SetsPlugin(t *testing.T) { + tracker := &mockPluginTracker{} + pw := &pluginWriter{ResponseWriter: tracker, plugin: "whoami"} + + msg := new(dns.Msg) + msg.SetQuestion("example.com.", dns.TypeA) + + err := pw.WriteMsg(msg) + if err != nil { + t.Fatalf("WriteMsg returned error: %v", err) + } + + if tracker.plugin != "whoami" { + t.Errorf("Expected plugin to be 'whoami', got %q", tracker.plugin) + } +} + +func TestPluginWriter_Write_SetsPlugin(t *testing.T) { + tracker := &mockPluginTracker{} + pw := &pluginWriter{ResponseWriter: tracker, plugin: "forward"} + + _, err := pw.Write([]byte("test")) + if err != nil { + t.Fatalf("Write returned error: %v", err) + } + + if tracker.plugin != "forward" { + t.Errorf("Expected plugin to be 'forward', got %q", tracker.plugin) + } +} + +func TestPluginWriter_NonTracker_NoError(t *testing.T) { + // When the underlying writer doesn't implement PluginTracker, + // WriteMsg should still work without error + mock := &mockResponseWriter{} + pw := &pluginWriter{ResponseWriter: mock, plugin: "whoami"} + + msg := new(dns.Msg) + msg.SetQuestion("example.com.", dns.TypeA) + + err := pw.WriteMsg(msg) + if err != nil { + t.Fatalf("WriteMsg returned error: %v", err) + } + + if mock.msg == nil { + t.Error("Expected message to be written") + } +} + +func TestNextOrFailure_WrapsWithPluginWriter(t *testing.T) { + tracker := &mockPluginTracker{} + handler := &mockHandler{name: "testplugin", writeMsg: true} + + req := new(dns.Msg) + req.SetQuestion("example.com.", dns.TypeA) + + _, err := NextOrFailure("caller", handler, context.Background(), tracker, req) + if err != nil { + t.Fatalf("NextOrFailure returned error: %v", err) + } + + // The handler should have written a message, which should have set the plugin + if tracker.plugin != "testplugin" { + t.Errorf("Expected plugin to be 'testplugin', got %q", tracker.plugin) + } +} + +func TestNextOrFailure_NilHandler(t *testing.T) { + mock := &mockResponseWriter{} + req := new(dns.Msg) + req.SetQuestion("example.com.", dns.TypeA) + + rcode, err := NextOrFailure("caller", nil, context.Background(), mock, req) + if err == nil { + t.Error("Expected error for nil handler") + } + if rcode != dns.RcodeServerFailure { + t.Errorf("Expected RcodeServerFailure, got %d", rcode) + } +} + +func TestPluginWriter_DelegatesMethods(t *testing.T) { + mock := &mockResponseWriter{} + pw := &pluginWriter{ResponseWriter: mock, plugin: "test"} + + // Test LocalAddr + if pw.LocalAddr() == nil { + t.Error("LocalAddr should not return nil") + } + + // Test RemoteAddr + if pw.RemoteAddr() == nil { + t.Error("RemoteAddr should not return nil") + } + + // Test Close + if err := pw.Close(); err != nil { + t.Errorf("Close returned error: %v", err) + } + + // Test TsigStatus + if err := pw.TsigStatus(); err != nil { + t.Errorf("TsigStatus returned error: %v", err) + } + + // Test TsigTimersOnly (should not panic) + pw.TsigTimersOnly(true) + + // Test Hijack (should not panic) + pw.Hijack() +}