# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.

import os
import time
from typing import Any, Callable, Optional

import dns.exception
import dns.message
import dns.query

import isctest.log
from isctest.compat import dns_rcode

# Default per-attempt timeout handed to the underlying query function.
QUERY_TIMEOUT = 10


def generic_query(
    query_func: Callable[..., Any],
    message: dns.message.Message,
    ip: str,
    port: Optional[int] = None,
    source: Optional[str] = None,
    timeout: int = QUERY_TIMEOUT,
    attempts: int = 10,
    # NOTE(review): annotation intentionally left un-wrapped (no Optional[...]);
    # dns_rcode comes from a compat shim not visible here - confirm it is
    # always a class before tightening the hint.
    expected_rcode: dns_rcode = None,
    verify: bool = False,
    log_query: bool = True,
    log_response: bool = True,
) -> Any:
    """Send *message* with *query_func*, retrying until a response is accepted.

    Args:
        query_func: Callable invoked as ``query_func(q=..., where=...,
            timeout=..., port=..., source=...)``.  When its ``__name__`` is
            ``"tls"``, a ``verify`` keyword is passed as well.
        message: The query to send.
        ip: Address of the server to query.
        port: Destination port.  When ``None``, it is read from the
            ``TLSPORT`` environment variable for a ``tls`` query function
            and from ``PORT`` otherwise.
        source: Optional source address passed through to *query_func*.
        timeout: Timeout passed to *query_func* on each attempt.
        attempts: Maximum number of times the query is sent.
        expected_rcode: If not ``None``, only a response carrying this rcode
            is accepted; any other response triggers a retry.
        verify: Forwarded as ``verify`` to a ``tls`` query function only.
        log_query: Include the query text in the debug log (first attempt
            only).
        log_response: Write each received response to the debug log.

    Returns:
        The first response that satisfies *expected_rcode* (or the first
        response at all when *expected_rcode* is ``None``).

    Raises:
        dns.exception.Timeout: No acceptable response was obtained within
            *attempts* tries.
        KeyError: *port* is ``None`` and the relevant environment variable
            is not set.
    """
    func_name = query_func.__name__
    is_tls = func_name == "tls"

    if port is None:
        if is_tls:
            port = int(os.environ["TLSPORT"])
        else:
            port = int(os.environ["PORT"])

    query_args = {
        "q": message,
        "where": ip,
        "timeout": timeout,
        "port": port,
        "source": source,
    }
    if is_tls:
        query_args["verify"] = verify

    # Holds the most recent response received, even if a later attempt fails
    # with an exception; used only for the final diagnostic message.
    res = None
    for attempt in range(attempts):
        log_msg = (
            f"isc.query.{func_name}(): ip={ip}, port={port}, source={source}, "
            f"timeout={timeout}, attempts left={attempts-attempt}"
        )
        if log_query:
            log_msg += f"\n{message.to_text()}"
            log_query = False  # only log query on first attempt
        isctest.log.debug(log_msg)

        try:
            res = query_func(**query_args)
        except (dns.exception.Timeout, ConnectionRefusedError) as e:
            isctest.log.debug(f"isc.query.{func_name}(): the '{e}' exception raised")
        else:
            if log_response:
                isctest.log.debug(
                    f"isc.query.{func_name}(): response\n{res.to_text()}"
                )
            if expected_rcode is None or res.rcode() == expected_rcode:
                return res

        # Pause between attempts only; sleeping after the final attempt
        # would merely delay the failure below by a second.
        if attempt < attempts - 1:
            time.sleep(1)

    if expected_rcode is not None:
        last_rcode = dns_rcode.to_text(res.rcode()) if res is not None else None
        isctest.log.debug(
            f"isc.query.{func_name}(): expected rcode={dns_rcode.to_text(expected_rcode)}, last rcode={last_rcode}"
        )
    raise dns.exception.Timeout


def udp(*args, **kwargs) -> Any:
    """Run generic_query() using dns.query.udp as the query function."""
    return generic_query(dns.query.udp, *args, **kwargs)


def tcp(*args, **kwargs) -> Any:
    """Run generic_query() using dns.query.tcp as the query function."""
    return generic_query(dns.query.tcp, *args, **kwargs)


def tls(*args, **kwargs) -> Any:
    """Run generic_query() using dns.query.tls as the query function.

    Raises:
        RuntimeError: A TypeError escaped from the query; it is re-raised
            (chained) with a hint that dnspython 2.5.0 or newer is required.
    """
    try:
        return generic_query(dns.query.tls, *args, **kwargs)
    except TypeError as e:
        raise RuntimeError(
            "dnspython 2.5.0 or newer is required for isctest.query.tls()"
        ) from e