package server import ( "fmt" "net" "time" log "git.wntrmute.dev/kyle/goutils/log" "git.wntrmute.dev/kyle/kdhcp/config" "git.wntrmute.dev/kyle/kdhcp/dhcp" "git.wntrmute.dev/kyle/kdhcp/iptools" "git.wntrmute.dev/kyle/kdhcp/leases" "github.com/benbjohnson/clock" ) const ( DefaultPool = "default" MaxPacketSize = 512 MaxResponseWait = 5 * time.Minute ) func addrReply(addr net.Addr) (net.Addr, error) { udpAddr, err := net.ResolveUDPAddr("udp4", addr.String()) if err != nil { return nil, err } if udpAddr.IP.Equal(net.IPv4zero) { return &net.UDPAddr{ IP: net.IPv4bcast, Port: udpAddr.Port, }, nil } return udpAddr, nil } type Server struct { Conn net.PacketConn Config *config.Config Pools map[string]*iptools.Pool Static map[string]*leases.Info clock clock.Clock } func (s *Server) Close() error { return s.Conn.Close() } func (s *Server) Bind() (err error) { // In order to read DHCP packets, we'll need to listen on all addresses. // That being said, we also want to limit our listening to the DHCP // network device. ip := net.IP([]byte{0, 0, 0, 0}) s.Conn, err = BindInterface(ip, s.Config.Port, s.Config.Interface) return err } func (s *Server) ReadFrom() ([]byte, net.Addr, error) { b := make([]byte, MaxPacketSize) n, addr, err := s.Conn.ReadFrom(b) if err != nil { return nil, nil, err } b = b[:n] return b, addr, nil } func (s *Server) ReadDHCPRequest() (*dhcp.Packet, net.Addr, error) { pkt, addr, err := s.ReadFrom() if err != nil { return nil, nil, err } log.Debugf("server: read packet from %s", addr) req, err := dhcp.ReadPacket(pkt) if err != nil { return nil, nil, err } return req, addr, nil } func (s *Server) WriteTo(b []byte, addr net.Addr) error { n, err := s.Conn.WriteTo(b, addr) if err != nil { return err } log.Debugf("wrote %d bytes to %s", n, addr) return err } func (s *Server) AcceptPacket() (*dhcp.Packet, net.Addr, error) { request, addr, err := s.ReadDHCPRequest() if err != nil { return nil, nil, err } log.Debugf("BOOTP request received from %x", request.HardwareAddress) return request, addr, nil } func (srv *Server) Listen() { go srv.updatePoolLoop() go func() { for { time.Sleep(5 * time.Minute) if err := srv.SaveLeases(); err != nil { log.Warningf("server: while saving leases: %s", err) } } }() for { req, addr, err := srv.AcceptPacket() if err != nil { log.Errf("server: error reading packet: %s", err) continue } if req.MessageType != dhcp.MessageTypeBootRequest { log.Debugf("ignoring BOOTP message type %d", req.MessageType) continue } switch req.DHCPType { case dhcp.DHCPMessageTypeDiscover: go srv.WriteDHCPResponse(req, addr) case dhcp.DHCPMessageTypeRequest: log.Debugf("not handling DHCP request") case dhcp.DHCPMessageTypeRelease: log.Debugf("not handling DHCP release") default: log.Debugf("not handling unknown request type %d", req.DHCPType) } } } func New(cfg *config.Config) (*Server, error) { srv := &Server{ Config: cfg, Pools: map[string]*iptools.Pool{}, Static: map[string]*leases.Info{}, clock: clock.New(), } if err := srv.loadPoolsFromConfig(); err != nil { return nil, err } if err := srv.LoadLeases(); err != nil { return nil, err } if err := srv.Bind(); err != nil { return nil, err } log.Infof("server: bound to %s:%s", cfg.Interface, srv.Conn.LocalAddr()) return srv, nil } func (srv *Server) SelectLease(req *dhcp.Packet, t time.Time) *leases.Info { if li, ok := srv.Static[req.HostName]; ok { return li } if pool, ok := srv.Pools[req.HostName]; ok { return pool.Peek(req, t, MaxResponseWait) } if pool, ok := srv.Pools[DefaultPool]; ok { return pool.Peek(req, t, MaxResponseWait) } return nil } func (srv *Server) AcceptLease(req *dhcp.Packet, li *leases.Info, t time.Time) error { if _, ok := srv.Static[li.HostName]; ok { return nil } for _, pool := range srv.Pools { if _, ok := pool.Active[li.Addr]; ok { return nil } if _, ok := pool.Limbo[li.Addr]; ok { return pool.Accept(li.Addr, t) } } return fmt.Errorf("could not accept unknown lease: %s", li) }