diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go index 27385a467..a855120fa 100644 --- a/plugin/pkg/proxy/connect.go +++ b/plugin/pkg/proxy/connect.go @@ -18,6 +18,13 @@ import ( "github.com/miekg/dns" ) +const ( + ErrTransportStopped = "proxy: transport stopped" + ErrTransportStoppedDuringDial = "proxy: transport stopped during dial" + ErrTransportStoppedRetClosed = "proxy: transport stopped, ret channel closed" + ErrTransportStoppedDuringRetWait = "proxy: transport stopped during ret wait" +) + // limitTimeout is a utility function to auto-tune timeout values // average observed time is moved towards the last observed delay moderated by a weight // next timeout to use will be the double of the computed average, limited by min and max frame. @@ -52,25 +59,48 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { proto = "tcp-tls" } - t.dial <- proto - pc := <-t.ret - - if pc != nil { - connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) - return pc, true, nil + // Check if transport is stopped before attempting to dial + select { + case <-t.stop: + return nil, false, errors.New(ErrTransportStopped) + default: } - connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) - reqTime := time.Now() - timeout := t.dialTimeout() - if proto == "tcp-tls" { - conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) + // Use select to avoid blocking if connManager has stopped + select { + case t.dial <- proto: + // Successfully sent dial request + case <-t.stop: + return nil, false, errors.New(ErrTransportStoppedDuringDial) + } + + // Receive response with stop awareness + select { + case pc, ok := <-t.ret: + if !ok { + // ret channel was closed by connManager during stop + return nil, false, errors.New(ErrTransportStoppedRetClosed) + } + + if pc != nil { + connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) + return pc, true, nil + } + connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) t.updateDialTimeout(time.Since(reqTime)) return &persistConn{c: conn}, false, err + case <-t.stop: + return nil, false, errors.New(ErrTransportStoppedDuringRetWait) } - conn, err := dns.DialTimeout(proto, t.addr, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return &persistConn{c: conn}, false, err } // Connect selects an upstream, sends the request and waits for a response. diff --git a/plugin/pkg/proxy/connect_test.go b/plugin/pkg/proxy/connect_test.go new file mode 100644 index 000000000..edc68b68d --- /dev/null +++ b/plugin/pkg/proxy/connect_test.go @@ -0,0 +1,255 @@ +package proxy + +import ( + "sync" + "testing" + "time" +) + +const ( + testMsgExpectedError = "expected error" + testMsgUnexpectedNilError = "unexpected nil error" + testMsgWrongError = "wrong error message" +) + +// TestDial_TransportStopped_InitialCheck tests that Dial returns ErrTransportStopped +// if the transport is stopped before Dial is called. +func TestDial_TransportStopped_InitialCheck(t *testing.T) { + tr := newTransport("test_initial_stop", "127.0.0.1:0") + tr.Start() + + tr.Stop() + time.Sleep(50 * time.Millisecond) // Ensure connManager processes stop and exits + + _, _, err := tr.Dial("udp") + if err == nil { + t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) + } + if err.Error() != ErrTransportStopped { + t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStopped) + } +} + +// TestDial_TransportStoppedDuringDialSend tests that Dial returns ErrTransportStoppedDuringDial +// if Stop() is called while Dial is attempting to send on the (blocked) t.dial channel. +// This is achieved by not starting the connManager, so t.dial remains unread. +func TestDial_TransportStoppedDuringDialSend(t *testing.T) { + tr := newTransport("test_during_dial_send", "127.0.0.1:0") + // No tr.Start() here. This ensures t.dial channel will block. + + dialErrChan := make(chan error, 1) + go func() { + // Dial will pass initial stop check (t.stop is open). + // Then it will block on `t.dial <- proto` because no connManager is reading. + _, _, err := tr.Dial("udp") + dialErrChan <- err + }() + + // Allow Dial goroutine to reach the blocking send on t.dial + time.Sleep(50 * time.Millisecond) + + tr.Stop() // Close t.stop. Dial's select should now pick <-t.stop. + + err := <-dialErrChan + if err == nil { + t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) + } + if err.Error() != ErrTransportStoppedDuringDial { + t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedDuringDial) + } +} + +// TestDial_TransportStoppedDuringRetWait tests that Dial returns ErrTransportStoppedDuringRetWait +// when the transport is stopped while Dial is waiting to receive from t.ret channel. +func TestDial_TransportStoppedDuringRetWait(t *testing.T) { + tr := newTransport("test_during_ret_wait", "127.0.0.1:0") + // Replace transport's channels to control interaction precisely + tr.dial = make(chan string) // Test-controlled, unbuffered + tr.ret = make(chan *persistConn) // Test-controlled, unbuffered + // tr.stop remains the original transport stop channel + + // Start connManager and use our controlled channels. + tr.Start() + + dialErrChan := make(chan error, 1) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + defer wg.Done() + // Dial will: + // 1. Pass initial stop check. + // 2. Send on our tr.dial. + // 3. Block on our tr.ret in its 3rd select. + // When tr.Stop() is called, this 3rd select should pick <-tr.stop. + _, _, err := tr.Dial("udp") + dialErrChan <- err + }() + + // Simulate connManager reading from our tr.dial. + // This unblocks the Dial goroutine's send. + var protoFromDial string + select { + case protoFromDial = <-tr.dial: + t.Logf("Test: Simulated connManager read '%s' from Dial via test-controlled tr.dial", protoFromDial) + case <-time.After(100 * time.Millisecond): + t.Fatal("Test: Timeout waiting for Dial to send on test-controlled tr.dial") + } + + // Stop the transport and the tr.stop channel + tr.Stop() + + wg.Wait() // Wait for Dial goroutine to complete. + err := <-dialErrChan + + if err == nil { + t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) + } + + // Expected error is ErrTransportStoppedDuringRetWait + // However, if connManager (using replaced channels) itself reacts to stop faster + // and somehow closes the test-controlled tr.ret (not its design), other errors are possible. + // But with tr.ret being ours and unwritten-to, Dial should pick tr.stop. + if err.Error() != ErrTransportStoppedDuringRetWait { + t.Errorf("%s: got '%v', want '%s' (or potentially '%s' if timing is very tight)", + testMsgWrongError, err, ErrTransportStoppedDuringRetWait, ErrTransportStopped) + } else { + t.Logf("SUCCESS: Dial correctly returned '%s'", ErrTransportStoppedDuringRetWait) + } +} + +// TestDial_Returns_ErrTransportStoppedRetClosed tests that Dial +// returns ErrTransportStoppedRetClosed when tr.ret is closed before Dial reads from it. +func TestDial_Returns_ErrTransportStoppedRetClosed(t *testing.T) { + tr := newTransport("test_returns_ret_closed", "127.0.0.1:0") + + // Replace transport channels with test-controlled ones + testDialChan := make(chan string, 1) // Buffered to allow non-blocking send by Dial + testRetChan := make(chan *persistConn) // This will be closed by the test + tr.dial = testDialChan + tr.ret = testRetChan + // tr.stop remains the original, initially open channel. + + dialErrChan := make(chan error, 1) + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + // Dial will: + // 1. Pass initial stop check (tr.stop is open). + // 2. Send "udp" on tr.dial (which is testDialChan). + // 3. Block on <-tr.ret (which is testRetChan) in its 3rd select. + // When testRetChan is closed, it will read (nil, false), hitting the target error. + _, _, err := tr.Dial("udp") + dialErrChan <- err + }() + + // Step 1: Simulate connManager reading the dial request from Dial. + // Read from testDialChan. This unblocks the Dial goroutine's send to testDialChan. + select { + case proto := <-testDialChan: + if proto != "udp" { + wg.Done() + t.Fatalf("Test: Dial sent wrong proto on testDialChan: got %s, want udp", proto) + } + t.Logf("Test: Simulated connManager received '%s' from Dial via testDialChan.", proto) + case <-time.After(100 * time.Millisecond): + // If Dial didn't send, the test is flawed or Dial is stuck before sending. + wg.Done() + t.Fatal("Test: Timeout waiting for Dial to send on testDialChan.") + } + + // Step 2: Simulate connManager stopping and closing its 'ret' channel. + close(testRetChan) + t.Logf("Test: Closed testRetChan (simulating connManager closing tr.ret).") + + // Step 3: Wait for the Dial goroutine to complete. + wg.Wait() + err := <-dialErrChan + + if err == nil { + t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) + } + + if err.Error() != ErrTransportStoppedRetClosed { + t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedRetClosed) + } else { + t.Logf("SUCCESS: Dial correctly returned '%s'", ErrTransportStoppedRetClosed) + } + + // Call tr.Stop() for completeness to close the original tr.stop channel. + // connManager was not started with original channels, so this mainly affects tr.stop. + tr.Stop() +} + +// TestDial_ConnManagerClosesRetOnStop verifies that connManager closes tr.ret upon stopping. +func TestDial_ConnManagerClosesRetOnStop(t *testing.T) { + tr := newTransport("test_connmanager_closes_ret", "127.0.0.1:0") + tr.Start() + + // Initiate a Dial to interact with connManager so tr.ret is used. + interactionDialErrChan := make(chan error, 1) + go func() { + _, _, err := tr.Dial("udp") + interactionDialErrChan <- err + }() + + // Allow the Dial goroutine to interact with connManager. + time.Sleep(100 * time.Millisecond) + + // Now stop the transport. connManager should clean up and close tr.ret. + tr.Stop() + + // Wait for connManager to fully stop and close its channels. + // This duration needs to be sufficient for the select loop in connManager to see <-t.stop, + // call t.cleanup(true), which in turn calls close(t.ret). + time.Sleep(50 * time.Millisecond) + + // Check if tr.ret is actually closed by trying a non-blocking read. + select { + case _, ok := <-tr.ret: + if !ok { + t.Logf("SUCCESS: tr.ret channel is closed as expected after transport stop.") + } else { + t.Errorf("FAIL: tr.ret channel was not closed after transport stop, or a value was read unexpectedly.") + } + default: + // This case means tr.ret is open but blocking (empty). + // This would be unexpected if connManager is supposed to close it on stop. + t.Errorf("FAIL: tr.ret channel is not closed and is blocking (or empty but open).") + } + + // Drain the error channel from the initial interaction Dial to ensure the goroutine finishes. + select { + case err := <-interactionDialErrChan: + if err != nil { + t.Logf("Interaction Dial completed with error (possibly expected due to 127.0.0.1:0 or race with Stop): %v", err) + } else { + t.Logf("Interaction Dial completed without error.") + } + case <-time.After(100 * time.Millisecond): // Timeout for safety if Dial hangs + t.Logf("Timeout waiting for interaction Dial to complete.") + } +} + +// TestDial_MultipleCallsAfterStop tests that multiple Dial calls after Stop +// consistently return ErrTransportStopped. +func TestDial_MultipleCallsAfterStop(t *testing.T) { + tr := newTransport("test_multiple_after_stop", "127.0.0.1:0") + tr.Start() + + tr.Stop() + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 3; i++ { + _, _, err := tr.Dial("udp") + if err == nil { + t.Errorf("Attempt %d: %s: %s", i+1, testMsgExpectedError, testMsgUnexpectedNilError) + continue + } + if err.Error() != ErrTransportStopped { + t.Errorf("Attempt %d: %s: got '%v', want '%s'", i+1, testMsgWrongError, err, ErrTransportStopped) + } + } +}