core: Add full TSIG verification in DoH transport (#8013)

* core: Add full TSIG verification in DoH transport

This PR add full TSIG verification in DoH using dns.TsigVerify()
7943

---------

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang
2026-04-09 05:24:00 -07:00
committed by GitHub
parent 18d692a986
commit c0e6e7cef3
4 changed files with 233 additions and 20 deletions

View File

@@ -77,6 +77,13 @@ func ResponseToMsg(resp *http.Response) (*dns.Msg, error) {
// RequestToMsg converts a http.Request to a dns message.
func RequestToMsg(req *http.Request) (*dns.Msg, error) {
msg, _, err := RequestToMsgWire(req)
return msg, err
}
// RequestToMsgWire converts a http.Request to a dns message and returns the
// original DNS wire bytes from the request.
func RequestToMsgWire(req *http.Request) (*dns.Msg, []byte, error) {
switch req.Method {
case http.MethodGet:
return requestToMsgGet(req)
@@ -85,55 +92,60 @@ func RequestToMsg(req *http.Request) (*dns.Msg, error) {
return requestToMsgPost(req)
default:
return nil, fmt.Errorf("method not allowed: %s", req.Method)
return nil, nil, fmt.Errorf("method not allowed: %s", req.Method)
}
}
// requestToMsgPost extracts the dns message from the request body.
func requestToMsgPost(req *http.Request) (*dns.Msg, error) {
func requestToMsgPost(req *http.Request) (*dns.Msg, []byte, error) {
defer req.Body.Close()
return toMsg(req.Body)
return toMsgWire(req.Body)
}
const maxDNSQuerySize = 65536
const maxBase64Len = (maxDNSQuerySize*8 + 5) / 6
// requestToMsgGet extract the dns message from the GET request.
func requestToMsgGet(req *http.Request) (*dns.Msg, error) {
func requestToMsgGet(req *http.Request) (*dns.Msg, []byte, error) {
values := req.URL.Query()
b64, ok := values["dns"]
if !ok {
return nil, fmt.Errorf("no 'dns' query parameter found")
return nil, nil, fmt.Errorf("no 'dns' query parameter found")
}
if len(b64) != 1 {
return nil, fmt.Errorf("multiple 'dns' query values found")
return nil, nil, fmt.Errorf("multiple 'dns' query values found")
}
if len(b64[0]) > maxBase64Len {
return nil, fmt.Errorf("dns query too large")
return nil, nil, fmt.Errorf("dns query too large")
}
return base64ToMsg(b64[0])
return base64ToMsgWire(b64[0])
}
func toMsg(r io.ReadCloser) (*dns.Msg, error) {
buf, err := io.ReadAll(http.MaxBytesReader(nil, r, maxDNSQuerySize))
if err != nil {
return nil, err
}
m := new(dns.Msg)
err = m.Unpack(buf)
m, _, err := toMsgWire(r)
return m, err
}
func base64ToMsg(b64 string) (*dns.Msg, error) {
func toMsgWire(r io.ReadCloser) (*dns.Msg, []byte, error) {
buf, err := io.ReadAll(http.MaxBytesReader(nil, r, maxDNSQuerySize))
if err != nil {
return nil, nil, err
}
m := new(dns.Msg)
err = m.Unpack(buf)
return m, buf, err
}
func base64ToMsgWire(b64 string) (*dns.Msg, []byte, error) {
buf, err := b64Enc.DecodeString(b64)
if err != nil {
return nil, err
return nil, nil, err
}
m := new(dns.Msg)
err = m.Unpack(buf)
return m, err
return m, buf, err
}
var b64Enc = base64.RawURLEncoding