/*
 * Decompiled with CFR 0.152.
 */
package org.xbill.DNS;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.DClass;
import org.xbill.DNS.EDNSOption;
import org.xbill.DNS.Message;
import org.xbill.DNS.Name;
import org.xbill.DNS.NioTcpClient;
import org.xbill.DNS.NioUdpClient;
import org.xbill.DNS.OPTRecord;
import org.xbill.DNS.Rcode;
import org.xbill.DNS.Record;
import org.xbill.DNS.Resolver;
import org.xbill.DNS.ResolverConfig;
import org.xbill.DNS.TSIG;
import org.xbill.DNS.Type;
import org.xbill.DNS.WireParseException;
import org.xbill.DNS.ZoneTransferException;
import org.xbill.DNS.ZoneTransferIn;

public class SimpleResolver
implements Resolver {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(SimpleResolver.class);
    public static final int DEFAULT_PORT = 53;
    public static final int DEFAULT_EDNS_PAYLOADSIZE = 1280;
    private InetSocketAddress address;
    private InetSocketAddress localAddress;
    private boolean useTCP;
    private boolean ignoreTruncation;
    private OPTRecord queryOPT = new OPTRecord(1280, 0, 0, 0);
    private TSIG tsig;
    private Duration timeoutValue = Duration.ofSeconds(10L);
    private static final short DEFAULT_UDPSIZE = 512;
    private static InetSocketAddress defaultResolver = new InetSocketAddress(InetAddress.getLoopbackAddress(), 53);

    public SimpleResolver() throws UnknownHostException {
        this((String)null);
    }

    public SimpleResolver(String hostname) throws UnknownHostException {
        if (hostname == null) {
            this.address = ResolverConfig.getCurrentConfig().server();
            if (this.address == null) {
                this.address = defaultResolver;
            }
            return;
        }
        InetAddress addr = "0".equals(hostname) ? InetAddress.getLoopbackAddress() : InetAddress.getByName(hostname);
        this.address = new InetSocketAddress(addr, 53);
    }

    public SimpleResolver(InetSocketAddress host) {
        this.address = Objects.requireNonNull(host, "host must not be null");
    }

    public SimpleResolver(InetAddress host) {
        Objects.requireNonNull(host, "host must not be null");
        this.address = new InetSocketAddress(host, 53);
    }

    public InetSocketAddress getAddress() {
        return this.address;
    }

    public static void setDefaultResolver(InetSocketAddress hostname) {
        defaultResolver = hostname;
    }

    public static void setDefaultResolver(String hostname) {
        defaultResolver = new InetSocketAddress(hostname, 53);
    }

    public int getPort() {
        return this.address.getPort();
    }

    @Override
    public void setPort(int port) {
        this.address = new InetSocketAddress(this.address.getAddress(), port);
    }

    public void setAddress(InetSocketAddress addr) {
        this.address = addr;
    }

    public void setAddress(InetAddress addr) {
        this.address = new InetSocketAddress(addr, this.address.getPort());
    }

    public void setLocalAddress(InetSocketAddress addr) {
        this.localAddress = addr;
    }

    public void setLocalAddress(InetAddress addr) {
        this.localAddress = new InetSocketAddress(addr, 0);
    }

    public boolean getTCP() {
        return this.useTCP;
    }

    @Override
    public void setTCP(boolean flag) {
        this.useTCP = flag;
    }

    public boolean getIgnoreTruncation() {
        return this.ignoreTruncation;
    }

    @Override
    public void setIgnoreTruncation(boolean flag) {
        this.ignoreTruncation = flag;
    }

    public OPTRecord getEDNS() {
        return this.queryOPT;
    }

    public void setEDNS(OPTRecord optRecord) {
        this.queryOPT = optRecord;
    }

    @Override
    public void setEDNS(int version, int payloadSize, int flags, List<EDNSOption> options) {
        switch (version) {
            case -1: {
                this.queryOPT = null;
                break;
            }
            case 0: {
                if (payloadSize == 0) {
                    payloadSize = 1280;
                }
                this.queryOPT = new OPTRecord(payloadSize, 0, version, flags, options);
                break;
            }
            default: {
                throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
            }
        }
    }

    public TSIG getTSIGKey() {
        return this.tsig;
    }

    @Override
    public void setTSIGKey(TSIG key) {
        this.tsig = key;
    }

    @Override
    public void setTimeout(Duration timeout) {
        this.timeoutValue = timeout;
    }

    @Override
    public Duration getTimeout() {
        return this.timeoutValue;
    }

    private Message parseMessage(byte[] b) throws WireParseException {
        try {
            return new Message(b);
        }
        catch (IOException e) {
            if (!(e instanceof WireParseException)) {
                throw new WireParseException("Error parsing message", e);
            }
            throw (WireParseException)e;
        }
    }

    private void verifyTSIG(Message query, Message response, byte[] b) {
        if (this.tsig == null) {
            return;
        }
        int error = this.tsig.verify(response, b, query.getGeneratedTSIG());
        log.debug("TSIG verify on message id {}: {}", (Object)query.getHeader().getID(), (Object)Rcode.TSIGstring(error));
    }

    private void applyEDNS(Message query) {
        if (this.queryOPT == null || query.getOPT() != null) {
            return;
        }
        query.addRecord(this.queryOPT, 3);
    }

    private int maxUDPSize(Message query) {
        OPTRecord opt = query.getOPT();
        if (opt == null) {
            return 512;
        }
        return opt.getPayloadSize();
    }

    @Override
    public CompletionStage<Message> sendAsync(Message query) {
        return this.sendAsync(query, ForkJoinPool.commonPool());
    }

    @Override
    public CompletionStage<Message> sendAsync(Message query, Executor executor) {
        Record question;
        if (query.getHeader().getOpcode() == 0 && (question = query.getQuestion()) != null && question.getType() == 252) {
            CompletableFuture<Message> f = new CompletableFuture<Message>();
            CompletableFuture.runAsync(() -> {
                try {
                    f.complete(this.sendAXFR(query));
                }
                catch (IOException e) {
                    f.completeExceptionally(e);
                }
            }, executor);
            return f;
        }
        Message ednsTsigQuery = query.clone();
        this.applyEDNS(ednsTsigQuery);
        if (this.tsig != null) {
            ednsTsigQuery.setTSIG(this.tsig, 0, null);
        }
        return this.sendAsync(ednsTsigQuery, this.useTCP, executor);
    }

    CompletableFuture<Message> sendAsync(Message query, boolean forceTcp, Executor executor) {
        boolean tcp;
        int qid = query.getHeader().getID();
        byte[] out = query.toWire(65535);
        int udpSize = this.maxUDPSize(query);
        boolean bl = tcp = forceTcp || out.length > udpSize;
        if (log.isTraceEnabled()) {
            log.trace("Sending {}/{}, id={} to {}/{}:{}, query:\n{}", new Object[]{query.getQuestion().getName(), Type.string(query.getQuestion().getType()), qid, tcp ? "tcp" : "udp", this.address.getAddress().getHostAddress(), this.address.getPort(), query});
        } else if (log.isDebugEnabled()) {
            log.debug("Sending {}/{}, id={} to {}/{}:{}", new Object[]{query.getQuestion().getName(), Type.string(query.getQuestion().getType()), qid, tcp ? "tcp" : "udp", this.address.getAddress().getHostAddress(), this.address.getPort()});
        }
        CompletableFuture<byte[]> result = tcp ? NioTcpClient.sendrecv(this.localAddress, this.address, query, out, this.timeoutValue) : NioUdpClient.sendrecv(this.localAddress, this.address, out, udpSize, this.timeoutValue);
        return result.thenComposeAsync(in -> {
            Message response;
            CompletableFuture<Message> f = new CompletableFuture<Message>();
            if (((byte[])in).length < 12) {
                f.completeExceptionally(new WireParseException("invalid DNS header - too short"));
                return f;
            }
            int id = ((in[0] & 0xFF) << 8) + (in[1] & 0xFF);
            if (id != qid) {
                f.completeExceptionally(new WireParseException("invalid message id: expected " + qid + "; got id " + id));
                return f;
            }
            try {
                response = this.parseMessage((byte[])in);
            }
            catch (WireParseException e) {
                f.completeExceptionally(e);
                return f;
            }
            if (!query.getQuestion().getName().equals(response.getQuestion().getName())) {
                f.completeExceptionally(new WireParseException("invalid name in message: expected " + query.getQuestion().getName() + "; got " + response.getQuestion().getName()));
                return f;
            }
            if (query.getQuestion().getDClass() != response.getQuestion().getDClass()) {
                f.completeExceptionally(new WireParseException("invalid class in message: expected " + DClass.string(query.getQuestion().getDClass()) + "; got " + DClass.string(response.getQuestion().getDClass())));
                return f;
            }
            if (query.getQuestion().getType() != response.getQuestion().getType()) {
                f.completeExceptionally(new WireParseException("invalid type in message: expected " + Type.string(query.getQuestion().getType()) + "; got " + Type.string(response.getQuestion().getType())));
                return f;
            }
            this.verifyTSIG(query, response, (byte[])in);
            if (!tcp && !this.ignoreTruncation && response.getHeader().getFlag(6)) {
                if (log.isTraceEnabled()) {
                    log.trace("Got truncated response for id {}, retrying via TCP, response:\n{}", (Object)qid, (Object)response);
                } else {
                    log.debug("Got truncated response for id {}, retrying via TCP", (Object)qid);
                }
                return this.sendAsync(query, true, executor);
            }
            response.setResolver(this);
            f.complete(response);
            return f;
        }, executor);
    }

    private Message sendAXFR(Message query) throws IOException {
        Name qname = query.getQuestion().getName();
        ZoneTransferIn xfrin2 = ZoneTransferIn.newAXFR(qname, this.address, this.tsig);
        xfrin2.setTimeout(this.timeoutValue);
        xfrin2.setLocalAddress(this.localAddress);
        try {
            xfrin2.run();
        }
        catch (ZoneTransferException e) {
            throw new WireParseException(e.getMessage());
        }
        List<Record> records = xfrin2.getAXFR();
        Message response = new Message(query.getHeader().getID());
        response.getHeader().setFlag(5);
        response.getHeader().setFlag(0);
        response.addRecord(query.getQuestion(), 0);
        for (Record r : records) {
            response.addRecord(r, 1);
        }
        return response;
    }

    public String toString() {
        return "SimpleResolver [" + this.address + "]";
    }
}

