diff --git a/plugin/pkg/proxyproto/proxyproto.go b/plugin/pkg/proxyproto/proxyproto.go index b8db1a53e..c22aa38f4 100644 --- a/plugin/pkg/proxyproto/proxyproto.go +++ b/plugin/pkg/proxyproto/proxyproto.go @@ -70,6 +70,7 @@ func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { if err != nil { return n, addr, err } + peer := addr n, addr, err = c.readFrom(p[:n], addr) if err != nil { if errors.Is(err, errHeaderOnly) { @@ -80,7 +81,7 @@ func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } // drop invalid packet as returning error would cause the ReadFrom caller to exit // which could result in DoS if an attacker sends intentional invalid packets - clog.Warningf("dropping invalid Proxy Protocol packet from %s: %v", addr.String(), err) + clog.Warningf("dropping invalid Proxy Protocol packet from %s: %v", peer.String(), err) continue } return n, addr, nil diff --git a/plugin/pkg/proxyproto/udp_session_test.go b/plugin/pkg/proxyproto/udp_session_test.go index 2580be41f..1f372ec32 100644 --- a/plugin/pkg/proxyproto/udp_session_test.go +++ b/plugin/pkg/proxyproto/udp_session_test.go @@ -1,6 +1,7 @@ package proxyproto import ( + "io" "net" "testing" "time" @@ -132,3 +133,61 @@ func TestStoreSessionEvictsOldest(t *testing.T) { t.Fatal("expected r3 to be present") } } + +func TestPacketConnReadFromMalformedPPv2NonUDPDoesNotPanic(t *testing.T) { + pc := &singlePacketConn{ + packet: []byte{ + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, + 0x55, 0x49, 0x54, 0x0a, // PPv2 signature + 0x21, // version 2, PROXY command + 0x11, // TCPv4: invalid for UDP PacketConn + 0x00, 0x0c, // address length + 127, 0, 0, 1, + 127, 0, 0, 1, + 0x30, 0x39, + 0x00, 0x35, + }, + addr: udpAddr("127.0.0.1", 12345), + } + + ppc := &PacketConn{ + PacketConn: pc, + ConnPolicy: func(proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { + return proxyproto.USE, nil + }, + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("ReadFrom panicked on malformed PPv2 non-UDP packet: %v", r) + } + }() + + buf := make([]byte, 512) + _, _, err := ppc.ReadFrom(buf) + if err != io.EOF { + t.Fatalf("ReadFrom err = %v, want io.EOF after dropping malformed packet", err) + } +} + +type singlePacketConn struct { + packet []byte + addr net.Addr + read bool +} + +func (c *singlePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + if c.read { + return 0, nil, io.EOF + } + c.read = true + copy(p, c.packet) + return len(c.packet), c.addr, nil +} + +func (c *singlePacketConn) WriteTo([]byte, net.Addr) (int, error) { return 0, nil } +func (c *singlePacketConn) Close() error { return nil } +func (c *singlePacketConn) LocalAddr() net.Addr { return udpAddr("127.0.0.1", 53) } +func (c *singlePacketConn) SetDeadline(time.Time) error { return nil } +func (c *singlePacketConn) SetReadDeadline(time.Time) error { return nil } +func (c *singlePacketConn) SetWriteDeadline(time.Time) error { return nil }