diff --git a/scapy/layers/dns.py b/scapy/layers/dns.py index 3681696a7b7..864ead516ad 100644 --- a/scapy/layers/dns.py +++ b/scapy/layers/dns.py @@ -1271,9 +1271,10 @@ def __getattr__(self, attr): class DNS(DNSCompressedPacket): name = "DNS" + FORCE_TCP = False fields_desc = [ ConditionalField(ShortField("length", None), - lambda p: isinstance(p.underlayer, TCP)), + lambda p: p.FORCE_TCP or isinstance(p.underlayer, TCP)), ShortField("id", 0), BitField("qr", 0, 1), BitEnumField("opcode", 0, 4, {0: "QUERY", 1: "IQUERY", 2: "STATUS"}), @@ -1300,7 +1301,7 @@ class DNS(DNSCompressedPacket): def get_full(self): # Required for DNSCompressedPacket - if isinstance(self.underlayer, TCP): + if isinstance(self.underlayer, TCP) or self.FORCE_TCP: return self.original[2:] else: return self.original @@ -1332,7 +1333,10 @@ def mysummary(self): ) def post_build(self, pkt, pay): - if isinstance(self.underlayer, TCP) and self.length is None: + if ( + (isinstance(self.underlayer, TCP) or self.FORCE_TCP) and + self.length is None + ): pkt = struct.pack("!H", len(pkt) - 2) + pkt[2:] return pkt + pay @@ -1363,6 +1367,14 @@ def pre_dissect(self, s): return s +class DNSTCP(DNS): + """ + A DNS packet that is always under TCP + """ + FORCE_TCP = True + match_subclass = True + + bind_layers(UDP, DNS, dport=5353) bind_layers(UDP, DNS, sport=5353) bind_layers(UDP, DNS, dport=53) @@ -1413,16 +1425,18 @@ def dns_resolve(qname, qtype="A", raw=False, tcp=False, verbose=1, timeout=3, ** try: # Spawn a socket, connect to the nameserver on port 53 if tcp: + cls = DNSTCP sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) else: + cls = DNS sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(kwargs["timeout"]) sock.connect((nameserver, 53)) # Connected. Wrap it with DNS - sock = StreamSocket(sock, DNS) + sock = StreamSocket(sock, cls) # I/O res = sock.sr1( - DNS(qd=[DNSQR(qname=qname, qtype=qtype)], id=RandShort()), + cls(qd=[DNSQR(qname=qname, qtype=qtype)], id=RandShort()), **kwargs, ) except IOError as ex: diff --git a/test/regression.uts b/test/regression.uts index 0cd5b0e1479..97bba309ca8 100644 --- a/test/regression.uts +++ b/test/regression.uts @@ -3500,16 +3500,11 @@ import socket sck = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sck.connect(("8.8.8.8", 53)) -class DNSTCP(Packet): - name = "DNS over TCP" - fields_desc = [ FieldLenField("len", None, fmt="!H", length_of="dns"), - PacketLenField("dns", 0, DNS, length_from=lambda p: p.len)] - ssck = StreamSocket(sck, DNSTCP) -r = ssck.sr1(DNSTCP(dns=DNS(rd=1, qd=DNSQR(qname="www.example.com"))), timeout=3) +r = ssck.sr1(DNSTCP(rd=1, qd=DNSQR(qname="www.example.com")), timeout=3) sck.close() -assert DNSTCP in r and len(r.dns.an) +assert DNSTCP in r and len(r.an) ############ + Tests of SSLStreamContext