200 lines
4.1 KiB
Go
200 lines
4.1 KiB
Go
package server
|
|
|
|
import (
	"errors"
	"fmt"
	"net"
	"time"

	log "git.wntrmute.dev/kyle/goutils/log"
	"git.wntrmute.dev/kyle/kdhcp/config"
	"git.wntrmute.dev/kyle/kdhcp/dhcp"
	"git.wntrmute.dev/kyle/kdhcp/iptools"
	"git.wntrmute.dev/kyle/kdhcp/leases"
	"github.com/benbjohnson/clock"
)
|
|
|
|
const (
	// DefaultPool is the name of the fallback address pool, consulted when a
	// request matches neither a static lease nor a pool named after the
	// requesting host.
	DefaultPool = "default"

	// MaxResponseWait is the wait duration handed to Pool.Peek when an
	// address is tentatively offered from a pool.
	MaxResponseWait = 5 * time.Minute

	// MaxPacketSize is the packet size limit, in bytes.
	//
	// NOTE(review): RFC 2131 §2 requires clients to accept an options field
	// of at least 312 octets (548-byte messages), so this value is smaller
	// than messages some clients send — confirm before relying on it as a
	// receive limit.
	MaxPacketSize = 512
)
|
|
|
|
// addrReply computes the destination for a reply to a packet received from
// addr.
//
// addr is re-resolved as a UDP/IPv4 address. A peer whose IP is the
// unspecified address 0.0.0.0 cannot be reached by unicast, so the reply is
// redirected to the limited broadcast address 255.255.255.255 on the same
// port. Any other peer is returned as resolved. Resolution failures are
// returned to the caller.
func addrReply(addr net.Addr) (net.Addr, error) {
	peer, err := net.ResolveUDPAddr("udp4", addr.String())
	switch {
	case err != nil:
		return nil, err
	case !peer.IP.Equal(net.IPv4zero):
		return peer, nil
	default:
		return &net.UDPAddr{IP: net.IPv4bcast, Port: peer.Port}, nil
	}
}
|
|
|
|
// Server holds the state of the DHCP server: its packet connection, its
// configuration, and the static and pooled leases it hands out.
type Server struct {
	// Conn is the packet connection used for reading requests and writing
	// replies. It is set by Bind.
	Conn net.PacketConn
	// Config is the server configuration. Bind reads its Port and Interface
	// fields.
	Config *config.Config
	// Pools maps pool names to address pools. SelectLease looks a pool up by
	// the request's host name first, then falls back to DefaultPool.
	Pools map[string]*iptools.Pool
	// Static maps host names to statically assigned leases.
	Static map[string]*leases.Info

	// clock is the server's time source; New sets it to clock.New().
	clock clock.Clock
}
|
|
|
|
func (s *Server) Close() error {
|
|
return s.Conn.Close()
|
|
}
|
|
|
|
func (s *Server) Bind() (err error) {
|
|
// In order to read DHCP packets, we'll need to listen on all addresses.
|
|
// That being said, we also want to limit our listening to the DHCP
|
|
// network device.
|
|
ip := net.IP([]byte{0, 0, 0, 0})
|
|
s.Conn, err = BindInterface(ip, s.Config.Port, s.Config.Interface)
|
|
return err
|
|
}
|
|
|
|
func (s *Server) ReadFrom() ([]byte, net.Addr, error) {
|
|
b := make([]byte, MaxPacketSize)
|
|
n, addr, err := s.Conn.ReadFrom(b)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
b = b[:n]
|
|
return b, addr, nil
|
|
}
|
|
|
|
func (s *Server) ReadDHCPRequest() (*dhcp.Packet, net.Addr, error) {
|
|
pkt, addr, err := s.ReadFrom()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
log.Debugf("server: read packet from %s", addr)
|
|
req, err := dhcp.ReadPacket(pkt)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return req, addr, nil
|
|
}
|
|
|
|
func (s *Server) WriteTo(b []byte, addr net.Addr) error {
|
|
n, err := s.Conn.WriteTo(b, addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Debugf("wrote %d bytes to %s", n, addr)
|
|
return err
|
|
}
|
|
|
|
func (s *Server) AcceptPacket() (*dhcp.Packet, net.Addr, error) {
|
|
request, addr, err := s.ReadDHCPRequest()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
log.Debugf("BOOTP request received from %x", request.HardwareAddress)
|
|
return request, addr, nil
|
|
}
|
|
|
|
func (srv *Server) Listen() {
|
|
go srv.updatePoolLoop()
|
|
go func() {
|
|
for {
|
|
time.Sleep(5 * time.Minute)
|
|
if err := srv.SaveLeases(); err != nil {
|
|
log.Warningf("server: while saving leases: %s", err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
req, addr, err := srv.AcceptPacket()
|
|
if err != nil {
|
|
log.Errf("server: error reading packet: %s", err)
|
|
continue
|
|
}
|
|
|
|
if req.MessageType != dhcp.MessageTypeBootRequest {
|
|
log.Debugf("ignoring BOOTP message type %d", req.MessageType)
|
|
continue
|
|
}
|
|
|
|
switch req.DHCPType {
|
|
case dhcp.DHCPMessageTypeDiscover:
|
|
go srv.WriteDHCPResponse(req, addr)
|
|
case dhcp.DHCPMessageTypeRequest:
|
|
log.Debugf("not handling DHCP request")
|
|
case dhcp.DHCPMessageTypeRelease:
|
|
log.Debugf("not handling DHCP release")
|
|
default:
|
|
log.Debugf("not handling unknown request type %d", req.DHCPType)
|
|
}
|
|
}
|
|
}
|
|
|
|
func New(cfg *config.Config) (*Server, error) {
|
|
srv := &Server{
|
|
Config: cfg,
|
|
Pools: map[string]*iptools.Pool{},
|
|
Static: map[string]*leases.Info{},
|
|
clock: clock.New(),
|
|
}
|
|
|
|
if err := srv.loadPoolsFromConfig(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := srv.LoadLeases(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := srv.Bind(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
log.Infof("server: bound to %s:%s", cfg.Interface, srv.Conn.LocalAddr())
|
|
|
|
return srv, nil
|
|
}
|
|
|
|
func (srv *Server) SelectLease(req *dhcp.Packet, t time.Time) *leases.Info {
|
|
if li, ok := srv.Static[req.HostName]; ok {
|
|
return li
|
|
}
|
|
|
|
if pool, ok := srv.Pools[req.HostName]; ok {
|
|
return pool.Peek(req, t, MaxResponseWait)
|
|
}
|
|
|
|
if pool, ok := srv.Pools[DefaultPool]; ok {
|
|
return pool.Peek(req, t, MaxResponseWait)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (srv *Server) AcceptLease(req *dhcp.Packet, li *leases.Info, t time.Time) error {
|
|
if _, ok := srv.Static[li.HostName]; ok {
|
|
return nil
|
|
}
|
|
|
|
for _, pool := range srv.Pools {
|
|
if _, ok := pool.Active[li.Addr]; ok {
|
|
return nil
|
|
}
|
|
|
|
if _, ok := pool.Limbo[li.Addr]; ok {
|
|
return pool.Accept(li.Addr, t)
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("could not accept unknown lease: %s", li)
|
|
}
|