package dhcp

import (
	"bytes"
	"errors"
	"fmt"
	"io"

	log "git.wntrmute.dev/kyle/goutils/log"
)

// optionRegistry maps each option tag this package knows how to parse to the
// function that parses it. ReadOption looks tags up here; any tag without an
// entry is handed to readUnknownOption instead.
var optionRegistry = map[OptionTag]Option{
	OptionTagEnd:                  OptionEnd,
	OptionTagClientID:             OptionClientID,
	OptionTagHostName:             OptionHostName,
	OptionTagMessageType:          OptionMessageType,
	OptionTagParameterRequestList: OptionParameterRequestList,
	OptionTagPadding:              OptionPad,
}

// OptionPad handles the Pad option. A pad is a lone zero octet with no
// length field and no payload, so nothing further is consumed from r and req
// is left unchanged; the function always succeeds.
func OptionPad(req *BootRequest, r io.Reader) error {
	return nil
}

// getOptionLength reads the single length octet that follows an option tag
// and returns it as an int.
//
// It is shared by every option parser, so its error messages are generic
// rather than naming a particular option. A length of zero is rejected. On
// any error the returned length is -1, and a read failure is wrapped with %w
// so callers can inspect it with errors.Is (for example, io.EOF on a
// truncated packet).
func getOptionLength(r io.Reader) (int, error) {
	var octet [1]byte
	if _, err := io.ReadFull(r, octet[:]); err != nil {
		return -1, fmt.Errorf("dhcp: reading option length: %w", err)
	}

	if octet[0] == 0 {
		return -1, errors.New("dhcp: read option length 0, but expected option length >= 1")
	}

	return int(octet[0]), nil
}

// OptionHostName reads a DHCP Host Name option and stores the name in
// req.HostName.
//
// From RFC 2132, section 3.14 (Host Name Option):
//
//	This option specifies the name of the client.  The name may or may
//	not be qualified with the local domain name (see section 3.17 for the
//	preferred way to retrieve the domain name).  See RFC 1035 for
//	character set restrictions.
//
//	The code for this option is 12, and its minimum length is 1.
//
// The option tag is expected to have been consumed already; r must be
// positioned at the option's length octet. The name bytes are stored as-is
// without any character-set validation.
func OptionHostName(req *BootRequest, r io.Reader) error {
	length, err := getOptionLength(r)
	if err != nil {
		return err
	}

	// io.ReadFull keeps reading until the buffer is full. A single Read may
	// legally return fewer bytes, or the final bytes together with io.EOF.
	hostName := make([]byte, length)
	if n, err := io.ReadFull(r, hostName); err != nil {
		return fmt.Errorf("dhcp: only read %d bytes of hostname, expected %d bytes: %w", n, length, err)
	}

	req.HostName = string(hostName)
	return nil
}

// OptionMessageType reads a DHCP Message Type option into req.DHCPType.
//
// The option's length octet must be exactly 1; any other length is an
// error. The single value octet is decoded through the package's packet
// reader. A failure to read it is wrapped with context so callers can tell
// which field was truncated.
func OptionMessageType(req *BootRequest, r io.Reader) error {
	read := newPacketReaderFunc(r)

	length, err := getOptionLength(r)
	if err != nil {
		return err
	}
	if length != 1 {
		return fmt.Errorf("dhcp: read option length %d, but expected option length for DHCP Message Type is 1", length)
	}

	if err := read(&req.DHCPType); err != nil {
		return fmt.Errorf("dhcp: while reading DHCP Message Type: %w", err)
	}

	return nil
}

// OptionParameterRequestList reads a DHCP Parameter Request List option.
//
// Each payload octet is interpreted as an OptionTag, logged at debug level,
// and appended to req.ParameterRequests in the order received. Existing
// entries in req.ParameterRequests are kept. At least one octet is required.
func OptionParameterRequestList(req *BootRequest, r io.Reader) error {
	length, err := getOptionLength(r)
	if err != nil {
		return err
	}
	// Keep the minimum-length rule local to this option, independent of how
	// getOptionLength treats zero.
	if length == 0 {
		return fmt.Errorf("dhcp: read option length %d, but expected option length for DHCP Parameter Request is >= 1", length)
	}

	// io.ReadFull, unlike a single Read, does not treat a short read or a
	// final (data, io.EOF) return as failure.
	parameters := make([]byte, length)
	if n, err := io.ReadFull(r, parameters); err != nil {
		return fmt.Errorf("dhcp: only read %d octets of requested parameters, expected %d octets: %w", n, length, err)
	}

	for _, parameter := range parameters {
		opt := OptionTag(parameter)
		log.Debugf("client is requesting %s", opt)
		req.ParameterRequests = append(req.ParameterRequests, opt)
	}

	return nil
}

// OptionClientID reads a DHCP Client Identifier option and stores its raw
// payload octets, converted verbatim to a string, in req.ClientID.
//
// At least one payload octet is required. The error message for a zero
// length names this option; it previously said "Parameter Request" by
// mistake.
func OptionClientID(req *BootRequest, r io.Reader) error {
	length, err := getOptionLength(r)
	if err != nil {
		return err
	}
	if length == 0 {
		return fmt.Errorf("dhcp: read option length %d, but expected option length for DHCP Client Identifier is >= 1", length)
	}

	// io.ReadFull tolerates readers that return the identifier in pieces.
	clientID := make([]byte, length)
	if n, err := io.ReadFull(r, clientID); err != nil {
		return fmt.Errorf("dhcp: only read %d bytes of client ID, expected %d bytes: %w", n, length, err)
	}

	req.ClientID = string(clientID)
	return nil
}

// OptionEnd handles the End option. The option has no length or payload,
// so nothing is consumed from r. It records that the End option was seen by
// setting req.endOptions, and it always succeeds.
func OptionEnd(req *BootRequest, r io.Reader) error {
	req.endOptions = true
	return nil
}

// ReadOption parses one option whose tag octet has already been read from r.
//
// If the tag has a parser in optionRegistry, that parser is logged at debug
// level and invoked. Otherwise the option is skipped by readUnknownOption.
// Either way, the parser's error is returned unchanged.
func ReadOption(req *BootRequest, tag byte, r io.Reader) error {
	opt := OptionTag(tag)

	parse, known := optionRegistry[opt]
	if !known {
		return readUnknownOption(req, tag, r)
	}

	log.Debugf("dhcp: reading option=%s", opt)
	return parse(req, r)
}

// readUnknownOption consumes an option whose tag has no parser: it reads the
// one-octet length, then that many data octets, and logs what was skipped.
// req is not modified.
//
// A zero length is valid here. The option's meaning is unknown, so no
// minimum can be enforced. The length octet is therefore read directly
// rather than through getOptionLength, which rejects zero; that rejection
// made the zero-length branch unreachable and caused zero-length unknown
// options to fail parsing.
func readUnknownOption(req *BootRequest, tag byte, r io.Reader) error {
	var lengthOctet [1]byte
	if _, err := io.ReadFull(r, lengthOctet[:]); err != nil {
		return fmt.Errorf("dhcp: reading length of unknown tag %d/%02x: %w", tag, tag, err)
	}

	length := int(lengthOctet[0])
	if length == 0 {
		log.Debugf("skipped option %d/%02x with length 0", tag, tag)
		return nil
	}

	// io.ReadFull, unlike a single Read, does not mis-report short reads.
	data := make([]byte, length)
	if n, err := io.ReadFull(r, data); err != nil {
		return fmt.Errorf("dhcp: only read %d bytes of unknown tag %d/%02x, expected %d bytes: %w", n, tag, tag, length, err)
	}

	log.Infof("skipped unknown tag %d/%02x with data %0x", tag, tag, data)
	return nil
}

// magicCookieLength is the size, in octets, of the magic cookie.
const magicCookieLength = 4

// magicCookie is the exact byte sequence ReadMagicCookie requires.
var magicCookie = []byte{99, 130, 83, 99}

// ReadMagicCookie reads magicCookieLength octets from r and verifies that
// they equal magicCookie.
//
// The cookie is read with io.ReadFull, so readers that deliver it in several
// pieces, or that return the last bytes together with io.EOF, are accepted.
// Errors:
//   - If nothing could be read, the returned error wraps the reader's error
//     (typically io.EOF).
//   - If only part of the cookie was read, it wraps io.ErrUnexpectedEOF.
//   - A complete cookie with the wrong value produces a mismatch error.
func ReadMagicCookie(r io.Reader) error {
	var cookie [magicCookieLength]byte
	n, err := io.ReadFull(r, cookie[:])
	if err != nil {
		if n == 0 {
			return fmt.Errorf("dhcp: failed to read magic cookie: %w", err)
		}
		return fmt.Errorf("dhcp: read %d bytes, expected to read %d bytes for the magic cookie: %w",
			n, magicCookieLength, err)
	}

	if !bytes.Equal(cookie[:], magicCookie) {
		return fmt.Errorf("dhcp: read magic cookie %x, expected %x", cookie[:], magicCookie)
	}

	return nil
}