/*
 * Licensed to The OpenNMS Group, Inc (TOG) under one or more
 * contributor license agreements.  See the LICENSE.md file
 * distributed with this work for additional information
 * regarding copyright ownership.
 *
 * TOG licenses this file to You under the GNU Affero General
 * Public License Version 3 (the "License") or (at your option)
 * any later version.  You may not use this file except in
 * compliance with the License.  You may obtain a copy of the
 * License at:
 *
 *      https://www.gnu.org/licenses/agpl-3.0.txt
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied.  See the License for the specific
 * language governing permissions and limitations under the
 * License.
 */
package org.opennms.core.test.dns;

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.InterruptedIOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.concurrent.CountDownLatch;

import org.apache.commons.io.IOUtils;
import org.opennms.core.utils.InetAddressUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xbill.DNS.Address;
import org.xbill.DNS.CNAMERecord;
import org.xbill.DNS.Cache;
import org.xbill.DNS.Credibility;
import org.xbill.DNS.DClass;
import org.xbill.DNS.DNAMERecord;
import org.xbill.DNS.ExtendedFlags;
import org.xbill.DNS.Flags;
import org.xbill.DNS.Header;
import org.xbill.DNS.Message;
import org.xbill.DNS.Name;
import org.xbill.DNS.NameTooLongException;
import org.xbill.DNS.OPTRecord;
import org.xbill.DNS.Opcode;
import org.xbill.DNS.RRset;
import org.xbill.DNS.Rcode;
import org.xbill.DNS.Record;
import org.xbill.DNS.Section;
import org.xbill.DNS.SetResponse;
import org.xbill.DNS.TSIG;
import org.xbill.DNS.TSIGRecord;
import org.xbill.DNS.Type;
import org.xbill.DNS.Zone;
import org.xbill.DNS.ZoneTransferException;

/**
 * A small, self-contained DNS server.
 * <p>
 * Answers queries from configured primary/secondary zones, falling back to a
 * per-class {@link Cache}, and verifies TSIG-signed requests against the keys
 * registered with {@link #addTSIG(String, String, String)}. {@link #start()}
 * launches one UDP and one TCP listener for every configured address/port pair
 * (defaulting to {@code 0.0.0.0} port 53); {@link #stop()} shuts them down.
 * AXFR and IXFR requests are only served over TCP, and both are answered with a
 * full zone transfer.
 */
public class DNSServer {

    private static final Logger LOG = LoggerFactory.getLogger(DNSServer.class);

    /** Socket timeout (ms) so the listener loops periodically re-check their stop flag. */
    private static final int DEFAULT_SOCKET_TIMEOUT = 100;

    /** Receive buffer size for UDP queries; longer datagrams are truncated by the socket. */
    private static final int UDP_MAX_LENGTH = 512;

    /**
     * Accepts TCP connections on one address/port and answers each
     * length-prefixed query on its own thread.
     */
    private final class TCPListener implements Stoppable {
        private final int m_port;
        private final InetAddress m_addr;
        private final CountDownLatch m_startingLatch = new CountDownLatch(1);
        private final CountDownLatch m_runningLatch = new CountDownLatch(1);
        private volatile boolean m_stopped = false;
        /** Why binding failed, if it did; reported to callers of {@link #await()}. */
        private volatile Exception m_startupFailure;
        private ServerSocket m_socket;

        private TCPListener(final int port, final InetAddress addr) {
            m_port = port;
            m_addr = addr;
        }

        /**
         * Blocks until the listener has either bound its socket or failed to.
         *
         * @throws IllegalStateException if the server socket could not be bound
         */
        @Override
        public void await() throws InterruptedException {
            m_startingLatch.await();
            final Exception failure = m_startupFailure;
            if (failure != null) {
                throw new IllegalStateException("unable to listen on TCP " + addrport(m_addr, m_port), failure);
            }
        }

        @Override
        public void run() {
            try {
                bind();
                while (!m_stopped) {
                    try {
                        final Socket s = m_socket.accept();
                        final Thread t = new Thread(new Runnable() {
                            @Override
                            public void run() {
                                handleConnection(s);
                            }
                        });
                        t.start();
                    } catch (final SocketTimeoutException e) {
                        // expected: gives the loop a chance to re-check m_stopped
                        LOG.trace("timed out waiting for request", e);
                    }
                }
            } catch (final IOException e) {
                LOG.warn("unable to serve socket on {}", addrport(m_addr, m_port), e);
            } finally {
                // m_socket is still null if the ServerSocket constructor failed
                if (m_socket != null) {
                    try {
                        m_socket.close();
                    } catch (final IOException e) {
                        LOG.debug("error while closing socket", e);
                    }
                }
                m_runningLatch.countDown();
            }
        }

        /**
         * Opens the server socket. Always releases {@link #await()}, recording the
         * failure first so a bind error is reported instead of hanging the caller.
         */
        private void bind() throws IOException {
            try {
                m_socket = new ServerSocket(m_port, 128, m_addr);
                m_socket.setSoTimeout(DEFAULT_SOCKET_TIMEOUT);
            } catch (final IOException | RuntimeException e) {
                m_startupFailure = e;
                throw e;
            } finally {
                m_startingLatch.countDown();
            }
        }

        /** Reads one length-prefixed query from {@code s}, writes the reply (if any), then closes the connection. */
        private void handleConnection(final Socket s) {
            InputStream is = null;
            DataInputStream dataIn = null;
            DataOutputStream dataOut = null;
            try {
                is = s.getInputStream();
                dataIn = new DataInputStream(is);
                // DNS over TCP prefixes every message with a two-byte length
                final int inLength = dataIn.readUnsignedShort();
                final byte[] in = new byte[inLength];
                dataIn.readFully(in);

                byte[] response;
                try {
                    final Message query = new Message(in);
                    LOG.debug("received query: {}", query);
                    response = generateReply(query, in, in.length, s);
                } catch (final IOException e) {
                    response = formerrMessage(in);
                }

                if (response == null) {
                    // nothing to write: the request was itself a response, an AXFR already
                    // streamed by doAXFR, or a header too short to build a FORMERR from
                    LOG.debug("returned response: {}", (Object) null);
                    return;
                }
                dataOut = new DataOutputStream(s.getOutputStream());
                dataOut.writeShort(response.length);
                dataOut.write(response);
                // parse for logging only after the reply is sent, and only when it will be logged
                if (LOG.isDebugEnabled()) {
                    LOG.debug("returned response: {}", new Message(response));
                }
            } catch (final SocketTimeoutException e) {
                LOG.trace("timed out waiting for request", e);
            } catch (final IOException e) {
                LOG.warn("error while processing socket", e);
            } finally {
                IOUtils.closeQuietly(dataOut);
                IOUtils.closeQuietly(dataIn);
                IOUtils.closeQuietly(is);
                IOUtils.closeQuietly(s);
            }
        }

        /** Signals the accept loop to exit and waits until the server socket is closed. */
        @Override
        public void stop() {
            m_stopped = true;
            try {
                m_runningLatch.await();
            } catch (final InterruptedException e) {
                LOG.warn("interrupted while stopping TCP listener", e);
                Thread.currentThread().interrupt();
            }
        }

        @Override
        public String toString() {
            return "TCPListener[" + addrport(m_addr, m_port) + "]";
        }
    }

    /** Receives UDP datagrams on one address/port and answers each on the listener thread. */
    private final class UDPListener implements Stoppable {
        private final int m_port;
        private final InetAddress m_addr;
        private final CountDownLatch m_startingLatch = new CountDownLatch(1);
        private final CountDownLatch m_runningLatch = new CountDownLatch(1);
        private volatile boolean m_stopped = false;
        /** Why binding failed, if it did; reported to callers of {@link #await()}. */
        private volatile Exception m_startupFailure;

        private UDPListener(final int port, final InetAddress addr) {
            m_port = port;
            m_addr = addr;
        }

        /**
         * Blocks until the listener has either bound its socket or failed to.
         *
         * @throws IllegalStateException if the datagram socket could not be bound
         */
        @Override
        public void await() throws InterruptedException {
            m_startingLatch.await();
            final Exception failure = m_startupFailure;
            if (failure != null) {
                throw new IllegalStateException("unable to listen on UDP " + addrport(m_addr, m_port), failure);
            }
        }

        @Override
        public void run() {
            DatagramSocket sock = null;
            try {
                try {
                    sock = new DatagramSocket(m_port, m_addr);
                    sock.setSoTimeout(DEFAULT_SOCKET_TIMEOUT);
                } catch (final IOException | RuntimeException e) {
                    m_startupFailure = e;
                    throw e;
                } finally {
                    // always release await(), even when binding failed
                    m_startingLatch.countDown();
                }
                serve(sock);
            } catch (final IOException e) {
                LOG.warn("error in the UDP listener: {}", addrport(m_addr, m_port), e);
            } finally {
                if (sock != null) {
                    try {
                        sock.close();
                    } catch (final Exception e) {
                        LOG.debug("error while closing socket", e);
                    }
                }
                m_runningLatch.countDown();
            }
        }

        /** Answers datagrams received on {@code sock} until {@link #stop()} is called. */
        private void serve(final DatagramSocket sock) throws IOException {
            final byte[] in = new byte[UDP_MAX_LENGTH];
            final DatagramPacket indp = new DatagramPacket(in, in.length);
            DatagramPacket outdp = null;
            while (!m_stopped) {
                indp.setLength(in.length);
                try {
                    sock.receive(indp);
                } catch (final InterruptedIOException e) {
                    // receive timed out: loop around to re-check m_stopped
                    continue;
                }

                byte[] response;
                try {
                    final Message query = new Message(in);
                    response = generateReply(query, in, indp.getLength(), null);
                } catch (final IOException e) {
                    response = formerrMessage(in);
                } catch (final RuntimeException e) {
                    // one bad datagram must not take down the whole listener
                    LOG.warn("error while processing UDP request from {}", addrport(indp.getAddress(), indp.getPort()), e);
                    continue;
                }
                if (response == null) {
                    continue;
                }

                // reuse a single outgoing packet, re-targeted at each sender
                if (outdp == null) {
                    outdp = new DatagramPacket(response, response.length, indp.getAddress(), indp.getPort());
                } else {
                    outdp.setData(response);
                    outdp.setLength(response.length);
                    outdp.setAddress(indp.getAddress());
                    outdp.setPort(indp.getPort());
                }
                sock.send(outdp);
            }
        }

        /** Signals the receive loop to exit and waits until the socket is closed. */
        @Override
        public void stop() {
            m_stopped = true;
            try {
                m_runningLatch.await();
            } catch (final InterruptedException e) {
                LOG.warn("interrupted while waiting for server to stop", e);
                Thread.currentThread().interrupt();
            }
        }

        @Override
        public String toString() {
            return "UDPListener[" + addrport(m_addr, m_port) + "]";
        }
    }

    /** Include DNSSEC signatures alongside the data records. */
    static final int FLAG_DNSSECOK = 1;
    /** Include only the DNSSEC signatures (used for SIG/RRSIG queries). */
    static final int FLAG_SIGONLY = 2;

    final Map<Integer, Cache> m_caches = new HashMap<Integer, Cache>();
    final Map<Name, Zone> m_znames = new HashMap<Name, Zone>();
    final Map<Name, TSIG> m_TSIGs = new HashMap<Name, TSIG>();
    final List<Integer> m_ports = new ArrayList<>();
    final List<InetAddress> m_addresses = new ArrayList<>();

    final List<Stoppable> m_activeListeners = new ArrayList<>();

    /** Formats an address/port pair as {@code address#port} for log messages. */
    private static String addrport(final InetAddress addr, final int port) {
        return InetAddressUtils.str(addr) + "#" + port;
    }

    /**
     * Creates a server configured from a jnamed-style configuration file.
     *
     * @param conffile path of the configuration file to read
     */
    public DNSServer(final String conffile) throws IOException, ZoneTransferException, ConfigurationException {
        parseConfiguration(conffile);
    }

    /** Creates an unconfigured server; add zones, ports and addresses before {@link #start()}. */
    public DNSServer() throws UnknownHostException {
    }

    /**
     * Starts a UDP and a TCP listener for every configured address/port pair and
     * blocks until each pair is bound.
     *
     * @throws RuntimeException if any listener fails to bind or startup is interrupted
     */
    public void start() throws UnknownHostException {
        initializeDefaults();

        for (final InetAddress addr : m_addresses) {
            for (final Integer port : m_ports) {
                try {
                    final UDPListener udpListener = new UDPListener(port, addr);
                    final Thread udpThread = new Thread(udpListener, "DNSServer-UDP-" + addrport(addr, port));
                    udpThread.start();
                    m_activeListeners.add(udpListener);

                    final TCPListener tcpListener = new TCPListener(port, addr);
                    final Thread tcpThread = new Thread(tcpListener, "DNSServer-TCP-" + addrport(addr, port));
                    tcpThread.start();
                    m_activeListeners.add(tcpListener);

                    udpListener.await();
                    tcpListener.await();

                    LOG.info("listening on {}", addrport(addr, port));
                } catch (final InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Failed to start DNS server on " + InetAddressUtils.str(addr) + "/" + port, e);
                } catch (final Exception e) {
                    throw new RuntimeException("Failed to start DNS server on " + InetAddressUtils.str(addr) + "/" + port, e);
                }
            }
        }
        LOG.debug("finished starting up");
    }

    /** Stops every listener started by {@link #start()} and forgets them. */
    public void stop() {
        for (final Stoppable listener : m_activeListeners) {
            LOG.debug("stopping {}", listener);
            listener.stop();
            LOG.debug("stopped {}", listener);
        }
        // avoid re-stopping stale listeners if the server is started again
        m_activeListeners.clear();
    }

    /**
     * Reads a configuration file. Each non-blank line starts with a keyword
     * ({@code primary}, {@code secondary}, {@code cache}, {@code key},
     * {@code port}, {@code address}); lines whose first token starts with
     * {@code #} are comments.
     *
     * @throws ConfigurationException if the file cannot be opened
     */
    protected void parseConfiguration(final String conffile) throws ConfigurationException, IOException,
    ZoneTransferException, UnknownHostException {
        final FileInputStream fs;
        try {
            fs = new FileInputStream(conffile);
        } catch (final Exception e) {
            LOG.error("Cannot open {}", conffile, e);
            throw new ConfigurationException("unable to read from " + conffile, e);
        }

        // closing the reader also closes the underlying FileInputStream
        try (final BufferedReader br = new BufferedReader(new InputStreamReader(fs))) {
            String line;
            while ((line = br.readLine()) != null) {
                parseConfigurationLine(line);
            }
        }
    }

    /** Applies a single configuration line; blank and comment lines are ignored. */
    private void parseConfigurationLine(final String line) throws IOException, ZoneTransferException {
        final StringTokenizer st = new StringTokenizer(line);
        if (!st.hasMoreTokens()) {
            return;
        }
        final String keyword = st.nextToken();
        // check for comments first, so a lone "#" is not reported as unparseable
        if (keyword.charAt(0) == '#') {
            return;
        }
        if (!st.hasMoreTokens()) {
            LOG.warn("unable to parse line: {}", line);
            return;
        }
        switch (keyword) {
            case "primary":
                addPrimaryZone(st.nextToken(), st.nextToken());
                break;
            case "secondary":
                addSecondaryZone(st.nextToken(), st.nextToken());
                break;
            case "cache":
                m_caches.put(Integer.valueOf(DClass.IN), new Cache(st.nextToken()));
                break;
            case "key": {
                final String s1 = st.nextToken();
                final String s2 = st.nextToken();
                if (st.hasMoreTokens()) {
                    addTSIG(s1, s2, st.nextToken());
                } else {
                    // two-argument form: name and key, algorithm defaults to hmac-md5
                    addTSIG("hmac-md5", s1, s2);
                }
                break;
            }
            case "port":
                m_ports.add(Integer.valueOf(st.nextToken()));
                break;
            case "address":
                m_addresses.add(Address.getByAddress(st.nextToken()));
                break;
            default:
                LOG.warn("unknown keyword: {}", keyword);
                break;
        }
    }

    /** Falls back to port 53 and the wildcard address {@code 0.0.0.0} when none are configured. */
    protected void initializeDefaults() throws UnknownHostException {
        if (m_ports.isEmpty()) {
            m_ports.add(Integer.valueOf(53));
        }

        if (m_addresses.isEmpty()) {
            m_addresses.add(Address.getByAddress("0.0.0.0"));
        }
    }

    public void addPort(final int port) {
        m_ports.add(port);
    }

    /** Replaces the configured ports with a copy of {@code ports}. */
    public void setPorts(final List<Integer> ports) {
        if (m_ports == ports) return;
        m_ports.clear();
        m_ports.addAll(ports);
    }

    public void addAddress(final InetAddress address) {
        m_addresses.add(address);
    }

    /** Replaces the configured listen addresses with a copy of {@code addresses}. */
    public void setAddresses(final List<InetAddress> addresses) {
        if (m_addresses == addresses) return;
        m_addresses.clear();
        m_addresses.addAll(addresses);
    }

    /** Registers {@code zone} under its origin, replacing any zone with the same origin. */
    public void addZone(final Zone zone) {
        m_znames.put(zone.getOrigin(), zone);
    }

    /** Loads a primary zone from {@code zonefile}; {@code zname} may be null. */
    public void addPrimaryZone(final String zname, final String zonefile) throws IOException {
        Name origin = null;
        if (zname != null) {
            origin = Name.fromString(zname, Name.root);
        }
        final Zone newzone = new Zone(origin, zonefile);
        m_znames.put(newzone.getOrigin(), newzone);
    }

    /** Loads a secondary zone {@code zone} of class IN by transferring it from {@code remote}. */
    public void addSecondaryZone(final String zone, final String remote) throws IOException, ZoneTransferException {
        final Name zname = Name.fromString(zone, Name.root);
        final Zone newzone = new Zone(zname, DClass.IN, remote);
        m_znames.put(zname, newzone);
    }

    /** Registers a TSIG key used to verify and sign requests carrying that key name. */
    public void addTSIG(final String algstr, final String namestr, final String key) throws IOException {
        final Name name = Name.fromString(namestr, Name.root);
        m_TSIGs.put(name, new TSIG(algstr, namestr, key));
    }

    /** Returns the cache for {@code dclass}, creating and registering an empty one on first use. */
    public Cache getCache(final int dclass) {
        Cache c = m_caches.get(dclass);
        if (c == null) {
            c = new Cache(dclass);
            m_caches.put(Integer.valueOf(dclass), c);
        }
        return c;
    }

    /**
     * Returns the zone whose origin is the longest suffix of {@code name}, or
     * {@code null} if no registered zone encloses it.
     */
    public Zone findBestZone(final Name name) {
        final Zone exact = m_znames.get(name);
        if (exact != null) {
            return exact;
        }
        final int labels = name.labels();
        // strip one leading label at a time, most specific suffix first
        for (int i = 1; i < labels; i++) {
            final Zone foundzone = m_znames.get(new Name(name, i));
            if (foundzone != null) {
                return foundzone;
            }
        }
        return null;
    }

    /**
     * Looks up an RRset of exactly {@code name}/{@code type}, first in the enclosing
     * zone and otherwise in the cache for {@code dclass}.
     *
     * @param glue when true, cache lookups also accept lower-credibility data
     * @return the matching RRset, or {@code null}
     */
    public RRset findExactMatch(final Name name, final int type, final int dclass, final boolean glue) {
        final Zone zone = findBestZone(name);
        if (zone != null) {
            return zone.findExactMatch(name, type);
        }
        final Cache cache = getCache(dclass);
        final List<RRset> rrsets = glue ? cache.findAnyRecords(name, type) : cache.findRecords(name, type);
        if (rrsets == null || rrsets.isEmpty()) {
            return null;
        }
        return rrsets.get(0); /* not quite right */
    }

    /**
     * Adds {@code rrset} to {@code section} of {@code response}, unless an RRset of the
     * same name and type already appears in that section or an earlier non-question one.
     * Records owned by a wildcard are rewritten to {@code name} unless {@code name} is
     * itself a wildcard. {@code flags} selects data records and/or signatures.
     */
    void addRRset(final Name name, final Message response, final RRset rrset, final int section, final int flags) {
        for (int s = 1; s <= section; s++) {
            if (response.findRRset(name, rrset.getType(), s)) return;
        }
        if ((flags & FLAG_SIGONLY) == 0) {
            for (final Record r : rrset.rrs()) {
                if (r.getName().isWild() && !name.isWild()) {
                    response.addRecord(r.withName(name), section);
                } else {
                    response.addRecord(r, section);
                }
            }
        }
        if ((flags & (FLAG_SIGONLY | FLAG_DNSSECOK)) != 0) {
            for (final Record r : rrset.sigs()) {
                if (r.getName().isWild() && !name.isWild()) {
                    response.addRecord(r.withName(name), section);
                } else {
                    response.addRecord(r, section);
                }
            }
        }
    }

    private void addSOA(final Message response, final Zone zone) {
        response.addRecord(zone.getSOA(), Section.AUTHORITY);
    }

    private void addNS(final Message response, final Zone zone, final int flags) {
        final RRset nsRecords = zone.getNS();
        addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
    }

    /** Adds the cached delegation NS records for {@code name}, if the cache has a delegation. */
    private void addCacheNS(final Message response, final Cache cache, final Name name) {
        final SetResponse sr = cache.lookupRecords(name, Type.NS, Credibility.HINT);
        if (!sr.isDelegation()) return;
        final RRset nsRecords = sr.getNS();
        for (final Record r : nsRecords.rrs()) {
            response.addRecord(r, Section.AUTHORITY);
        }
    }

    /** Adds any known A records for {@code name} to the additional section. */
    private void addGlue(final Message response, final Name name, final int flags) {
        final RRset a = findExactMatch(name, Type.A, DClass.IN, true);
        if (a == null) return;
        addRRset(name, response, a, Section.ADDITIONAL, flags);
    }

    /** Adds glue for every record in {@code section} that names another host. */
    private void addAdditional2(final Message response, final int section, final int flags) {
        final Record[] records = response.getSectionArray(section);
        for (final Record r : records) {
            final Name glueName = r.getAdditionalName();
            if (glueName != null) addGlue(response, glueName, flags);
        }
    }

    private void addAdditional(final Message response, final int flags) {
        addAdditional2(response, Section.ANSWER, flags);
        addAdditional2(response, Section.AUTHORITY, flags);
    }

    /**
     * Fills the answer/authority sections of {@code response} for one question,
     * following CNAME and DNAME chains recursively.
     *
     * @param iterations current chain depth; lookups past depth 6 stop silently
     * @return the rcode for the response (NOERROR, NXDOMAIN or YXDOMAIN)
     */
    byte addAnswer(final Message response, final Name name, int type, int dclass, int iterations, int flags) {
        SetResponse sr;
        byte rcode = Rcode.NOERROR;

        if (iterations > 6)
            return Rcode.NOERROR;

        // a SIG/RRSIG query is answered with the signatures of every type at the name
        if (type == Type.SIG || type == Type.RRSIG) {
            type = Type.ANY;
            flags |= FLAG_SIGONLY;
        }

        final Zone zone = findBestZone(name);
        if (zone != null) {
            sr = zone.findRecords(name, type);
        } else {
            sr = getCache(dclass).lookupRecords(name, type, Credibility.NORMAL);
        }

        if (sr.isUnknown()) {
            addCacheNS(response, getCache(dclass), name);
        }
        // AA is only set for the first link of a chain answered from a local zone
        if (sr.isNXDOMAIN()) {
            response.getHeader().setRcode(Rcode.NXDOMAIN);
            if (zone != null) {
                addSOA(response, zone);
                if (iterations == 0) response.getHeader().setFlag(Flags.AA);
            }
            rcode = Rcode.NXDOMAIN;
        } else if (sr.isNXRRSET()) {
            if (zone != null) {
                addSOA(response, zone);
                if (iterations == 0) response.getHeader().setFlag(Flags.AA);
            }
        } else if (sr.isDelegation()) {
            final RRset nsRecords = sr.getNS();
            addRRset(nsRecords.getName(), response, nsRecords, Section.AUTHORITY, flags);
        } else if (sr.isCNAME()) {
            final CNAMERecord cname = sr.getCNAME();
            addRRset(name, response, new RRset(cname), Section.ANSWER, flags);
            if (zone != null && iterations == 0) response.getHeader().setFlag(Flags.AA);
            rcode = addAnswer(response, cname.getTarget(), type, dclass, iterations + 1, flags);
        } else if (sr.isDNAME()) {
            final DNAMERecord dname = sr.getDNAME();
            RRset rrset = new RRset(dname);
            addRRset(name, response, rrset, Section.ANSWER, flags);
            final Name newname;
            try {
                newname = name.fromDNAME(dname);
            } catch (final NameTooLongException e) {
                return Rcode.YXDOMAIN;
            }
            // synthesize the CNAME implied by the DNAME substitution
            rrset = new RRset(new CNAMERecord(name, dclass, 0, newname));
            addRRset(name, response, rrset, Section.ANSWER, flags);
            if (zone != null && iterations == 0)
                response.getHeader().setFlag(Flags.AA);
            rcode = addAnswer(response, newname, type, dclass, iterations + 1, flags);
        } else if (sr.isSuccessful()) {
            final List<RRset> rrsets = sr.answers();
            for (final RRset rrset : rrsets) {
                addRRset(name, response, rrset, Section.ANSWER, flags);
            }
            if (zone != null) {
                addNS(response, zone, flags);
                if (iterations == 0)
                    response.getHeader().setFlag(Flags.AA);
            } else {
                addCacheNS(response, getCache(dclass), name);
            }
        }
        return rcode;
    }

    /**
     * Streams zone {@code name} over TCP socket {@code s}, one RRset per message,
     * TSIG-signing the stream when {@code tsig} is given. The socket is always
     * closed afterwards.
     *
     * @return a REFUSED error message if the zone is unknown, otherwise {@code null}
     *         (the reply has already been written)
     */
    byte[] doAXFR(final Name name, final Message query, final TSIG tsig, TSIGRecord qtsig, final Socket s) {
        final Zone zone = m_znames.get(name);
        if (zone == null) {
            return errorMessage(query, Rcode.REFUSED);
        }
        try {
            final DataOutputStream dataOut = new DataOutputStream(s.getOutputStream());
            final int id = query.getHeader().getID();
            boolean first = true;
            final Iterator<RRset> it = zone.AXFR();
            while (it.hasNext()) {
                final RRset rrset = it.next();
                final Message response = new Message(id);
                final Header header = response.getHeader();
                header.setFlag(Flags.QR);
                header.setFlag(Flags.AA);
                addRRset(rrset.getName(), response, rrset, Section.ANSWER, FLAG_DNSSECOK);
                if (tsig != null) {
                    // each message's TSIG chains off the previous one
                    tsig.applyStream(response, qtsig, first);
                    qtsig = response.getTSIG();
                }
                first = false;
                final byte[] out = response.toWire();
                dataOut.writeShort(out.length);
                dataOut.write(out);
            }
        } catch (final IOException ex) {
            LOG.warn("AXFR failed", ex);
        } finally {
            try {
                s.close();
            } catch (final IOException ex) {
                LOG.warn("error closing socket", ex);
            }
        }
        return null;
    }

    /*
     * Note: a null return value means that the caller doesn't need to do
     * anything. This happens if the incoming message is itself a response
     * (QR set) or if this is an AXFR request over TCP.
     */
    byte[] generateReply(final Message query, final byte[] in, final int length, final Socket s) {
        final Header header = query.getHeader();

        if (header.getFlag(Flags.QR))
            return null;
        if (header.getRcode() != Rcode.NOERROR)
            return errorMessage(query, Rcode.FORMERR);
        if (header.getOpcode() != Opcode.QUERY)
            return errorMessage(query, Rcode.NOTIMP);

        final Record queryRecord = query.getQuestion();
        if (queryRecord == null) {
            // nothing to answer; previously this dereferenced null below
            return errorMessage(query, Rcode.FORMERR);
        }

        final TSIGRecord queryTSIG = query.getTSIG();
        TSIG tsig = null;
        if (queryTSIG != null) {
            tsig = m_TSIGs.get(queryTSIG.getName());
            if (tsig == null || tsig.verify(query, in, length, null) != Rcode.NOERROR)
                return formerrMessage(in);
        }

        final OPTRecord queryOPT = query.getOPT();

        // TCP replies may use the full 16-bit length; UDP honours the EDNS payload size
        final int maxLength;
        if (s != null)
            maxLength = 65535;
        else if (queryOPT != null)
            maxLength = Math.max(queryOPT.getPayloadSize(), 512);
        else
            maxLength = 512;

        int flags = 0;
        if (queryOPT != null && (queryOPT.getFlags() & ExtendedFlags.DO) != 0)
            flags = FLAG_DNSSECOK;

        final Message response = new Message(query.getHeader().getID());
        response.getHeader().setFlag(Flags.QR);
        if (query.getHeader().getFlag(Flags.RD)) {
            response.getHeader().setFlag(Flags.RD);
        }
        response.addRecord(queryRecord, Section.QUESTION);

        final Name name = queryRecord.getName();
        final int type = queryRecord.getType();
        final int dclass = queryRecord.getDClass();
        if ((type == Type.AXFR || type == Type.IXFR) && s != null)
            return doAXFR(name, query, tsig, queryTSIG, s);
        if (!Type.isRR(type) && type != Type.ANY)
            return errorMessage(query, Rcode.NOTIMP);

        final byte rcode = addAnswer(response, name, type, dclass, 0, flags);
        if (rcode != Rcode.NOERROR && rcode != Rcode.NXDOMAIN)
            return errorMessage(query, rcode);

        addAdditional(response, flags);

        if (queryOPT != null) {
            final int optflags = (flags == FLAG_DNSSECOK) ? ExtendedFlags.DO : 0;
            // The OPT record's extended-rcode field carries only the bits above the
            // 4-bit header rcode; passing the full rcode made NXDOMAIN read as 51.
            final OPTRecord opt = new OPTRecord((short) 4096, (byte) (rcode >>> 4), (byte) 0, optflags);
            response.addRecord(opt, Section.ADDITIONAL);
        }

        response.setTSIG(tsig, Rcode.NOERROR, queryTSIG);
        return response.toWire(maxLength);
    }

    /**
     * Builds an empty response reusing {@code header} (and so its ID and flags),
     * marked as a response with the given rcode. The question is echoed only for
     * SERVFAIL. Note: {@code header} is modified in place.
     */
    byte[] buildErrorMessage(final Header header, final int rcode, final Record question) {
        final Message response = new Message();
        response.setHeader(header);
        // clear all four sections: question, answer, authority, additional
        for (int i = 0; i < 4; i++)
            response.removeAllRecords(i);
        if (rcode == Rcode.SERVFAIL && question != null)
            response.addRecord(question, Section.QUESTION);
        // an error reply is still a response; without QR it would look like a query
        header.setFlag(Flags.QR);
        header.setRcode(rcode);
        return response.toWire();
    }

    /**
     * Builds a FORMERR reply from the raw request bytes.
     *
     * @return the reply, or {@code null} if not even a header could be parsed from {@code in}
     */
    public byte[] formerrMessage(final byte[] in) {
        try {
            return buildErrorMessage(new Header(in), Rcode.FORMERR, null);
        } catch (final IOException e) {
            LOG.debug("unable to build error message", e);
            return null;
        }
    }

    /** Builds an error reply with {@code rcode} for a parsed {@code query}. */
    public byte[] errorMessage(final Message query, final int rcode) {
        return buildErrorMessage(query.getHeader(), rcode, query.getQuestion());
    }
}
