diff --git a/config/BUILD.bazel b/config/BUILD.bazel index c5a66e8..8c9639d 100644 --- a/config/BUILD.bazel +++ b/config/BUILD.bazel @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//iptools", + "//leases", "@dev_wntrmute_git_kyle_goutils//log", "@in_gopkg_yaml_v2//:yaml_v2", ], diff --git a/config/config.go b/config/config.go index d79d132..3289c57 100644 --- a/config/config.go +++ b/config/config.go @@ -8,6 +8,7 @@ import ( log "git.wntrmute.dev/kyle/goutils/log" "git.wntrmute.dev/kyle/kdhcp/iptools" + "git.wntrmute.dev/kyle/kdhcp/leases" "gopkg.in/yaml.v2" ) @@ -33,6 +34,24 @@ type Network struct { Domain string `yaml:"domain"` } +func (n Network) NetworkInfo() *leases.Server { + lnet := &leases.Network{ + Mask: iptools.NetIPtoAddr(n.Mask), + Gateway: iptools.NetIPtoAddr(n.Gateway), + Domain: n.Domain, + Broadcast: iptools.NetIPtoAddr(n.Broadcast), + } + + for _, addr := range n.DNS { + lnet.DNS = append(lnet.DNS, iptools.NetIPtoAddr(addr)) + } + + return &leases.Server{ + Addr: iptools.NetIPtoAddr(n.IP), + Network: lnet, + } +} + func (n *Network) ensureV4() (err error) { n.IP, err = ensureV4(n.IP) if err != nil { diff --git a/deps.bzl b/deps.bzl index 19959cc..0748420 100644 --- a/deps.bzl +++ b/deps.bzl @@ -4,8 +4,8 @@ def go_dependencies(): go_repository( name = "com_github_benbjohnson_clock", importpath = "github.com/benbjohnson/clock", - sum = "h1:g+rSsSaAzhHJYcIQE78hJ3AhyjjtQvleKDjlhdBnIhc=", - version = "v1.3.3", + sum = "h1:wj3BFPrTw8yYgA1OlMqvUk95nc8OMv3cvBSF5erT2W4=", + version = "v1.3.4", ) go_repository( diff --git a/dhcp.pcap b/dhcp.pcap new file mode 100644 index 0000000..f11d34f Binary files /dev/null and b/dhcp.pcap differ diff --git a/dhcp/BUILD.bazel b/dhcp/BUILD.bazel index bd1c59c..325ead5 100644 --- a/dhcp/BUILD.bazel +++ b/dhcp/BUILD.bazel @@ -1,13 +1,25 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "dhcp", srcs = [ + "offer.go", "options.go", "packet.go", + "parameters.go", "read_options.go", ], importpath = "git.wntrmute.dev/kyle/kdhcp/dhcp", visibility = ["//visibility:public"], - deps = ["@dev_wntrmute_git_kyle_goutils//log"], + deps = [ + "//leases", + "@dev_wntrmute_git_kyle_goutils//assert", + "@dev_wntrmute_git_kyle_goutils//log", + ], +) + +go_test( + name = "dhcp_test", + srcs = ["parameters_test.go"], + embed = [":dhcp"], ) diff --git a/dhcp/offer.go b/dhcp/offer.go new file mode 100644 index 0000000..230c2b1 --- /dev/null +++ b/dhcp/offer.go @@ -0,0 +1,36 @@ +package dhcp + +import ( + "net/netip" + + "git.wntrmute.dev/kyle/kdhcp/leases" +) + +func NewOffer(request *Packet, server *leases.Server, lease *leases.Info) (offer *Packet, err error) { + packet := &Packet{ + MessageType: MessageTypeBootResponse, + HardwareType: request.HardwareType, + HardwareAddress: request.HardwareAddress, + Hops: 0, + TransactionID: request.TransactionID, + SecondsElapsed: 0, + Flags: request.Flags, + ServerName: "", + FileName: "", + ClientIP: request.ClientIP, + YourIP: lease.Addr, + NextIP: netip.IPv4Unspecified(), + RelayIP: netip.IPv4Unspecified(), + DHCPType: DHCPMessageTypeOffer, + HostName: "", + ClientID: request.ClientID, + Parameters: []Parameter{}, + } + + packet.AddParameter(ParameterDNS(server.Network.DNS)) + packet.AddParameter(ParameterHostName(lease.HostName)) + packet.AddParameter(ParameterRouter(server.Network.Gateway)) + packet.AddParameter(ParameterSubnetMask(server.Network.Mask)) + + return packet, nil +} diff --git a/dhcp/options.go b/dhcp/options.go index 67b3221..3dcd6a8 100644 --- a/dhcp/options.go +++ b/dhcp/options.go @@ -10,13 +10,13 @@ type OptionTag uint8 func (opt OptionTag) String() string { s, ok := optionStrings[opt] if !ok { - panic(fmt.Sprintf("no string for option %d", opt)) + return fmt.Sprintf("", opt) } return s } -type Option func(req *BootRequest, r io.Reader) error +type Option func(req *Packet, r io.Reader) error const ( OptionTagPadding OptionTag = 0 diff --git a/dhcp/packet.go b/dhcp/packet.go index fbb8c45..4e8ed1f 100644 --- a/dhcp/packet.go +++ b/dhcp/packet.go @@ -5,10 +5,12 @@ import ( "encoding/binary" "fmt" "io" - "net" + "net/netip" "strings" + "git.wntrmute.dev/kyle/goutils/assert" log "git.wntrmute.dev/kyle/goutils/log" + "git.wntrmute.dev/kyle/kdhcp/leases" ) const ( @@ -17,7 +19,10 @@ const ( maxFileName = 128 ) -var anyAddr = net.IP([]byte{0, 0, 0, 0}) +const ( + MessageTypeBootRequest = 1 + MessageTypeBootResponse = 2 +) func formatMAC(mac []byte) string { s := []string{} @@ -28,27 +33,16 @@ func formatMAC(mac []byte) string { return strings.Join(s, ":") } -type BootRequest struct { - MessageType uint8 - HardwareType uint8 - HardwareAddress []byte - Hops uint8 - TransactionID uint32 - SecondsElapsed uint16 - Flags uint16 - ServerName string - FileName string +func readIPv4(r io.Reader) (netip.Addr, error) { + var buf [4]byte - ClientIP net.IP - YourIP net.IP - NextIP net.IP - RelayIP net.IP + n, err := r.Read(buf[:]) + if err != nil { + return netip.Addr{}, fmt.Errorf("dhcp: while reading IPv4 address from reader: %w", err) + } - DHCPType DHCPMessageType // option 53 - HostName string // option 12 - ClientID string // option 61 - ParameterRequests []OptionTag - endOptions bool + assert.Bool(n == 4, fmt.Sprintf("read %d bytes, but expected to read 4 bytes", n)) + return netip.AddrFrom4(buf), nil } func newPacketReaderFunc(r io.Reader) func(v any) error { @@ -57,87 +51,113 @@ func newPacketReaderFunc(r io.Reader) func(v any) error { } } -func (req *BootRequest) Read(packet []byte) error { +func newPacketWriterFunc(w io.Writer) func(v any) error { + return func(v any) error { + return binary.Write(w, binary.BigEndian, v) + } +} + +type Packet struct { + MessageType uint8 + HardwareType uint8 + HardwareAddress leases.HardwareAddress + Hops uint8 + TransactionID uint32 + SecondsElapsed uint16 + Flags uint16 + ServerName string + FileName string + + ClientIP netip.Addr + YourIP netip.Addr + NextIP netip.Addr + RelayIP netip.Addr + + DHCPType DHCPMessageType // option 53 + HostName string // option 12 + ClientID string // option 61 + ParameterRequests []OptionTag + Parameters []Parameter + endOptions bool +} + +func (pkt *Packet) Read(packet []byte) (err error) { buf := bytes.NewBuffer(packet) read := newPacketReaderFunc(buf) - if err := read(&req.MessageType); err != nil { + if err = read(&pkt.MessageType); err != nil { return err } - if err := read(&req.HardwareType); err != nil { + if err = read(&pkt.HardwareType); err != nil { return err } var hwaLength uint8 - if err := read(&hwaLength); err != nil { + if err = read(&hwaLength); err != nil { return err } - if err := read(&req.Hops); err != nil { + if err = read(&pkt.Hops); err != nil { return err } - if err := read(&req.TransactionID); err != nil { + if err = read(&pkt.TransactionID); err != nil { return err } - if err := read(&req.SecondsElapsed); err != nil { + if err = read(&pkt.SecondsElapsed); err != nil { return err } - if err := read(&req.Flags); err != nil { + if err = read(&pkt.Flags); err != nil { return err } - req.ClientIP = anyAddr[:] - if _, err := buf.Read(req.ClientIP); err != nil { + if pkt.ClientIP, err = readIPv4(buf); err != nil { return err } - req.YourIP = anyAddr[:] - if _, err := buf.Read(req.YourIP); err != nil { + if pkt.YourIP, err = readIPv4(buf); err != nil { return err } - req.NextIP = anyAddr[:] - if _, err := buf.Read(req.NextIP); err != nil { + if pkt.NextIP, err = readIPv4(buf); err != nil { return err } - req.RelayIP = anyAddr[:] - if _, err := buf.Read(req.RelayIP); err != nil { + if pkt.RelayIP, err = readIPv4(buf); err != nil { return err } - req.HardwareAddress = make([]byte, int(hwaLength)) - if _, err := buf.Read(req.HardwareAddress); err != nil { + pkt.HardwareAddress = make([]byte, int(hwaLength)) + if _, err = buf.Read(pkt.HardwareAddress); err != nil { return err } hwaPad := make([]byte, maxHardwareAddrLen-hwaLength) - if _, err := buf.Read(hwaPad); err != nil { + if _, err = buf.Read(hwaPad); err != nil { return err } tempBuf := make([]byte, maxServerName) - if _, err := buf.Read(tempBuf); err != nil { + if _, err = buf.Read(tempBuf); err != nil { return err } - req.ServerName = string(bytes.Trim(tempBuf, "\x00")) + pkt.ServerName = string(bytes.Trim(tempBuf, "\x00")) tempBuf = make([]byte, maxFileName) - if _, err := buf.Read(tempBuf); err != nil { + if _, err = buf.Read(tempBuf); err != nil { return err } - req.FileName = string(bytes.Trim(tempBuf, "\x00")) + pkt.FileName = string(bytes.Trim(tempBuf, "\x00")) - if err := ReadMagicCookie(buf); err != nil { + if err = ReadMagicCookie(buf); err != nil { return err } for { - if req.endOptions { + if pkt.endOptions { break } @@ -149,9 +169,9 @@ func (req *BootRequest) Read(packet []byte) error { return err } - err = ReadOption(req, tag, buf) + err = ReadOption(pkt, tag, buf) if err != nil { - log.Spew(*req) + log.Spew(*pkt) log.Spew(packet) return err } @@ -159,13 +179,119 @@ func (req *BootRequest) Read(packet []byte) error { return nil } -func ReadRequest(pkt []byte) (*BootRequest, error) { - req := &BootRequest{} - err := req.Read(pkt) +func ReadPacket(pkt []byte) (*Packet, error) { + packet := &Packet{} + err := packet.Read(pkt) if err != nil { return nil, err } - log.Debugf("dhcp: BOOTP request with txid %d for %s", req.TransactionID, formatMAC(req.HardwareAddress)) - return req, nil + log.Debugf("dhcp: BOOTP packet with txid %d for %s", packet.TransactionID, formatMAC(packet.HardwareAddress)) + return packet, nil +} + +func (pkt *Packet) Write(w io.Writer) error { + write := newPacketWriterFunc(w) + + if err := write(pkt.MessageType); err != nil { + return err + } + + if err := write(pkt.HardwareType); err != nil { + return err + } + + if err := write(uint8(len(pkt.HardwareAddress))); err != nil { + return err + } + + if err := write(pkt.Hops); err != nil { + return err + } + + if err := write(pkt.TransactionID); err != nil { + return err + } + + if err := write(pkt.SecondsElapsed); err != nil { + return err + } + + if err := write(pkt.Flags); err != nil { + return err + } + + if err := write(pkt.ClientIP.AsSlice()); err != nil { + return err + } + + if err := write(pkt.YourIP.AsSlice()); err != nil { + return err + } + + if err := write(pkt.NextIP.AsSlice()); err != nil { + return err + } + + if err := write(pkt.RelayIP.AsSlice()); err != nil { + return err + } + + padding := make([]byte, maxHardwareAddrLen-len(pkt.HardwareAddress)) + if err := write(pkt.HardwareAddress); err != nil { + return err + } + + if err := write(padding); err != nil { + return err + } + + padding = make([]byte, maxServerName-len(pkt.ServerName)) + if err := write([]byte(pkt.ServerName)); err != nil { + return err + } + + if err := write(padding); err != nil { + return err + } + + padding = make([]byte, maxFileName-len(pkt.FileName)) + if err := write([]byte(pkt.FileName)); err != nil { + return err + } + + if err := write(padding); err != nil { + return err + } + + if err := write(magicCookie); err != nil { + return err + } + + // TODO: write parameter request write + + for _, p := range pkt.Parameters { + if !p.Valid() { + continue + } + + if err := write(p.Bytes()); err != nil { + return err + } + } + + if err := write(OptionTagEnd); err != nil { + return err + } + + return nil +} + +func WritePacket(pkt *Packet) ([]byte, error) { + buf := &bytes.Buffer{} + if err := pkt.Write(buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil } diff --git a/dhcp/parameters.go b/dhcp/parameters.go new file mode 100644 index 0000000..f50ad0b --- /dev/null +++ b/dhcp/parameters.go @@ -0,0 +1,124 @@ +package dhcp + +import "net/netip" + +type Parameter interface { + Code() OptionTag + Len() int + Bytes() []byte // the serialized parameter + Valid() bool +} + +func (packet *Packet) AddParameter(p Parameter) { + packet.Parameters = append(packet.Parameters, p) +} + +type parameterWithAddr struct { + code OptionTag + addr netip.Addr +} + +func (p parameterWithAddr) Code() OptionTag { + return p.code +} + +func (p parameterWithAddr) Len() int { + return len(p.addr.AsSlice()) +} + +func (p parameterWithAddr) Bytes() []byte { + return append([]byte{byte(p.code), uint8(p.Len())}, + p.addr.AsSlice()...) +} + +func (p parameterWithAddr) Valid() bool { + return p.addr.Is4() && p.addr.IsValid() && !p.addr.IsUnspecified() +} + +type parameterWithAddrs struct { + code OptionTag + addrs []netip.Addr +} + +func (p parameterWithAddrs) Code() OptionTag { + return p.code +} + +func (p parameterWithAddrs) Len() int { + plen := 0 + + for _, addr := range p.addrs { + plen += len(addr.AsSlice()) + } + + return plen +} + +func (p parameterWithAddrs) Bytes() []byte { + out := []byte{byte(p.code), byte(p.Len())} + + for _, addr := range p.addrs { + out = append(out, addr.AsSlice()...) + } + + return out +} + +func (p parameterWithAddrs) Valid() bool { + for _, addr := range p.addrs { + if !(addr.Is4() && addr.IsValid() && !addr.IsUnspecified()) { + return false + } + } + return true +} + +type parameterWithString struct { + code OptionTag + data string +} + +func (p parameterWithString) Code() OptionTag { + return p.code +} + +func (p parameterWithString) Len() int { + return len(p.data) +} + +func (p parameterWithString) Bytes() []byte { + out := []byte{byte(p.code), byte(p.Len())} + return append(out, []byte(p.data)...) +} + +func (p parameterWithString) Valid() bool { + return len(p.data) > 0 +} + +func ParameterRouter(router netip.Addr) Parameter { + return ¶meterWithAddr{ + code: OptionTagRouter, + addr: router, + } +} + +func ParameterSubnetMask(mask netip.Addr) Parameter { + return ¶meterWithAddr{ + code: OptionTagSubnetMask, + addr: mask, + } +} + +func ParameterDNS(dns []netip.Addr) Parameter { + return ¶meterWithAddrs{ + code: OptionTagDomainNameServer, + addrs: dns, + } +} + +func ParameterHostName(name string) Parameter { + return ¶meterWithString{ + code: OptionTagHostName, + data: name, + } +} diff --git a/dhcp/parameters_test.go b/dhcp/parameters_test.go new file mode 100644 index 0000000..3d5b2f1 --- /dev/null +++ b/dhcp/parameters_test.go @@ -0,0 +1,35 @@ +package dhcp + +import ( + "bytes" + "net/netip" + "testing" +) + +func Test_parameterSubnetMask(t *testing.T) { + expected := []byte{ + byte(OptionTagSubnetMask), 4, 255, 255, 255, 0, + } + + prefix := netip.MustParseAddr("255.255.255.0") + param := ParameterSubnetMask(prefix) + + out := param.Bytes() + if !bytes.Equal(expected, out) { + t.Fatalf("invalid parameter subnet mask: have %x, want %x", out, expected) + } +} + +func Test_parameterRouter(t *testing.T) { + expected := []byte{ + byte(OptionTagRouter), 4, 192, 168, 1, 1, + } + + router := netip.MustParseAddr("192.168.1.1") + param := ParameterRouter(router) + + out := param.Bytes() + if !bytes.Equal(expected, out) { + t.Fatalf("invalid parameter router: have %x, want %x", out, expected) + } +} diff --git a/dhcp/read_options.go b/dhcp/read_options.go index 44fd552..5c0123a 100644 --- a/dhcp/read_options.go +++ b/dhcp/read_options.go @@ -18,7 +18,7 @@ var optionRegistry = map[OptionTag]Option{ OptionTagEnd: OptionEnd, } -func OptionPad(req *BootRequest, r io.Reader) error { +func OptionPad(req *Packet, r io.Reader) error { // The padding option is a single 0 byte octet. return nil } @@ -46,7 +46,7 @@ func getOptionLength(r io.Reader) (int, error) { // character set restrictions. // The code for this option is 12, and its minimum length is 1. -func OptionHostName(req *BootRequest, r io.Reader) error { +func OptionHostName(req *Packet, r io.Reader) error { length, err := getOptionLength(r) if err != nil { return err @@ -63,7 +63,7 @@ func OptionHostName(req *BootRequest, r io.Reader) error { return nil } -func OptionMessageType(req *BootRequest, r io.Reader) error { +func OptionMessageType(req *Packet, r io.Reader) error { read := newPacketReaderFunc(r) if length, err := getOptionLength(r); err != nil { @@ -79,7 +79,7 @@ func OptionMessageType(req *BootRequest, r io.Reader) error { return nil } -func OptionParameterRequestList(req *BootRequest, r io.Reader) error { +func OptionParameterRequestList(req *Packet, r io.Reader) error { length, err := getOptionLength(r) if err != nil { return err @@ -103,7 +103,7 @@ func OptionParameterRequestList(req *BootRequest, r io.Reader) error { return nil } -func OptionClientID(req *BootRequest, r io.Reader) error { +func OptionClientID(req *Packet, r io.Reader) error { length, err := getOptionLength(r) if err != nil { return err @@ -122,12 +122,12 @@ func OptionClientID(req *BootRequest, r io.Reader) error { return nil } -func OptionEnd(req *BootRequest, r io.Reader) error { +func OptionEnd(req *Packet, r io.Reader) error { req.endOptions = true return nil } -func ReadOption(req *BootRequest, tag byte, r io.Reader) error { +func ReadOption(req *Packet, tag byte, r io.Reader) error { opt := OptionTag(tag) if f, ok := optionRegistry[opt]; ok { log.Debugf("dhcp: reading option=%s", opt) @@ -137,7 +137,7 @@ func ReadOption(req *BootRequest, tag byte, r io.Reader) error { return readUnknownOption(req, tag, r) } -func readUnknownOption(req *BootRequest, tag byte, r io.Reader) error { +func readUnknownOption(req *Packet, tag byte, r io.Reader) error { length, err := getOptionLength(r) if err != nil { return err diff --git a/go.mod b/go.mod index 1a3b7bc..35c529e 100644 --- a/go.mod +++ b/go.mod @@ -2,16 +2,14 @@ module git.wntrmute.dev/kyle/kdhcp go 1.20 -require github.com/hashicorp/go-syslog v1.0.0 // indirect - require ( + git.wntrmute.dev/kyle/goutils v1.7.0 + github.com/benbjohnson/clock v1.3.4 github.com/peterbourgon/ff/v3 v3.3.0 gopkg.in/yaml.v2 v2.4.0 ) -require github.com/davecgh/go-spew v1.1.1 // indirect - require ( - git.wntrmute.dev/kyle/goutils v1.7.0 - github.com/benbjohnson/clock v1.3.3 + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/hashicorp/go-syslog v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index cd26237..c6aedc3 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ git.wntrmute.dev/kyle/goutils v1.7.0 h1:+lk6uUMcpJK49sEGEMCOns3WVd2ThH/htMWnsyXEGl8= git.wntrmute.dev/kyle/goutils v1.7.0/go.mod h1:hMcPr+XSYXjQ/IRTziNVYmUmb9BPATZc+cyehSjBs0w= -github.com/benbjohnson/clock v1.3.3 h1:g+rSsSaAzhHJYcIQE78hJ3AhyjjtQvleKDjlhdBnIhc= -github.com/benbjohnson/clock v1.3.3/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/benbjohnson/clock v1.3.4 h1:wj3BFPrTw8yYgA1OlMqvUk95nc8OMv3cvBSF5erT2W4= +github.com/benbjohnson/clock v1.3.4/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/hashicorp/go-syslog v1.0.0 h1:KaodqZuhUoZereWVIYmpUgZysurB1kBLX2j0MwMrUAE= diff --git a/iptools/BUILD.bazel b/iptools/BUILD.bazel index 3c5dbb1..7ba4966 100644 --- a/iptools/BUILD.bazel +++ b/iptools/BUILD.bazel @@ -3,8 +3,6 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "iptools", srcs = [ - "hardware_address.go", - "lease_info.go", "pool.go", "range.go", "tools.go", @@ -12,6 +10,8 @@ go_library( importpath = "git.wntrmute.dev/kyle/kdhcp/iptools", visibility = ["//visibility:public"], deps = [ + "//dhcp", + "//leases", "@dev_wntrmute_git_kyle_goutils//assert", "@dev_wntrmute_git_kyle_goutils//log", ], @@ -21,7 +21,6 @@ go_test( name = "iptools_test", size = "small", srcs = [ - "hardware_address_test.go", "pool_test.go", "range_test.go", ], diff --git a/iptools/pool.go b/iptools/pool.go index f247ea3..14e6e7a 100644 --- a/iptools/pool.go +++ b/iptools/pool.go @@ -1,28 +1,39 @@ package iptools import ( + "bytes" "fmt" "net/netip" "sync" "time" - "git.wntrmute.dev/kyle/goutils/assert" "git.wntrmute.dev/kyle/goutils/log" + "git.wntrmute.dev/kyle/kdhcp/dhcp" + "git.wntrmute.dev/kyle/kdhcp/leases" ) const DefaultExpiry = 168 * time.Hour type Pool struct { - Name string `yaml:"name"` - Range *Range `yaml:"range"` - Expiry time.Duration `yaml:"expiry"` - Available []*LeaseInfo `yaml:"available"` - Active map[netip.Addr]*LeaseInfo `yaml:"active"` - Limbo map[netip.Addr]*LeaseInfo `yaml:"limbo"` // leases that are currently being offered - NoHostName bool `yaml:"no_hostname"` // don't set the hostname + Name string `yaml:"name" json:"name"` + Expiry time.Duration `yaml:"expiry" json:"expiry"` + Available []*leases.Info `yaml:"available" json:"available"` + Active map[netip.Addr]*leases.Info `yaml:"active" json:"active"` + Limbo map[netip.Addr]*leases.Info `yaml:"limbo" json:"limbo"` // leases that are currently being offered + NoHostName bool `yaml:"no_hostname" json:"no_hostname"` // don't set the hostname mtx *sync.Mutex } +func (p *Pool) checkMaps() { + if p.Active == nil { + p.Active = map[netip.Addr]*leases.Info{} + } + + if p.Limbo == nil { + p.Limbo = map[netip.Addr]*leases.Info{} + } +} + func (p *Pool) lock() { if p.mtx == nil { p.mtx = &sync.Mutex{} @@ -49,14 +60,14 @@ func NewPool(name string, r *Range) (*Pool, error) { Expiry: r.Expiry, NoHostName: r.NoHostName, Available: enumerateRange(name, r, true), - Limbo: map[netip.Addr]*LeaseInfo{}, + Limbo: map[netip.Addr]*leases.Info{}, } return p, nil } func (p *Pool) sort() { - p.Available = SortLeases(p.Available) + p.Available = leases.Sort(p.Available) } func (p *Pool) Sort() { @@ -73,30 +84,51 @@ func (p *Pool) IsAddressAvailable() bool { // Peek returns the first available address from the pool and moves it // from available to limbo. When the client is given the address, Accept // should be called on the address to move it from limbo to active. -func (p *Pool) Peek(t time.Time, waitFor time.Duration) *LeaseInfo { +func (p *Pool) Peek(req *dhcp.Packet, t time.Time, waitFor time.Duration) *leases.Info { p.lock() defer p.unlock() - li := p.peek(t, waitFor) + li := p.peek(req, t, waitFor) return li } -func (p *Pool) peek(t time.Time, waitFor time.Duration) *LeaseInfo { +func (p *Pool) peek(req *dhcp.Packet, t time.Time, waitFor time.Duration) *leases.Info { + p.checkMaps() + + for _, li := range p.Active { + if bytes.Equal(req.HardwareAddress, li.HardwareAddress) { + log.Debugf("returning existing lease to %x of %s", req.HardwareAddress, li.Addr) + return li + } + } + + for _, li := range p.Limbo { + if bytes.Equal(req.HardwareAddress, li.HardwareAddress) { + log.Debugf("returning existing offer to %x of %s", req.HardwareAddress, li.Addr) + return li + } + } + if len(p.Available) == 0 { return nil } lease := p.Available[0] p.Available = p.Available[1:] - lease.ResetExpiry(t, waitFor) + lease.HardwareAddress = req.HardwareAddress + lease.ResetExpiry(t, waitFor) p.Limbo[lease.Addr] = lease return lease } func (p *Pool) accept(addr netip.Addr) error { - assert.Bool(p.Active[addr] == nil, fmt.Sprintf("limbo address %s is already active: %#v", addr, *p.Active[addr])) + p.checkMaps() + + if active, ok := p.Active[addr]; ok { + return fmt.Errorf("limbo address %s is already active: %#v", addr, *active) + } p.Active[addr] = p.Limbo[addr] delete(p.Limbo, addr) @@ -124,6 +156,8 @@ func (p *Pool) Update(t time.Time) bool { } func (p *Pool) update(t time.Time) bool { + p.checkMaps() + var updated bool for k, v := range p.Active { @@ -150,6 +184,8 @@ func (p *Pool) update(t time.Time) bool { return updated } -func (p *Pool) Renew(lease *LeaseInfo, t time.Time) { +func (p *Pool) Renew(lease *leases.Info, t time.Time) { + p.checkMaps() + p.Active[lease.Addr].ResetExpiry(t, p.Expiry) } diff --git a/iptools/range.go b/iptools/range.go index 50cfec4..889093f 100644 --- a/iptools/range.go +++ b/iptools/range.go @@ -11,11 +11,11 @@ const ( ) type Range struct { - Start netip.Addr `yaml:"start"` - End netip.Addr `yaml:"end"` - Network netip.Prefix `yaml:"network"` - Expiry time.Duration `yaml:"expiry"` - NoHostName bool `yaml:"no_hostname"` // don't set the hostname + Start netip.Addr `yaml:"start" json:"start"` + End netip.Addr `yaml:"end" json:"end"` + Network netip.Prefix `yaml:"network" json:"network"` + Expiry time.Duration `yaml:"expiry" json:"expiry"` + NoHostName bool `yaml:"no_hostname" json:"no_hostname"` // don't set the hostname } func (r *Range) Validate() error { diff --git a/iptools/tools.go b/iptools/tools.go index 0a4c1f1..5e2b1b2 100644 --- a/iptools/tools.go +++ b/iptools/tools.go @@ -2,9 +2,13 @@ package iptools import ( "fmt" + "net" + "net/netip" + + "git.wntrmute.dev/kyle/kdhcp/leases" ) -func enumerateRange(name string, r *Range, startFromOne bool) []*LeaseInfo { +func enumerateRange(name string, r *Range, startFromOne bool) []*leases.Info { start := r.Start cur := start lenfmt := fmt.Sprintf("%%s%%0%dd", len(fmt.Sprintf("%d", r.numHosts()))) @@ -12,11 +16,11 @@ func enumerateRange(name string, r *Range, startFromOne bool) []*LeaseInfo { if startFromOne { i++ } - leases := []*LeaseInfo{} + lrange := []*leases.Info{} for r.End.Compare(cur) >= 0 { hostName := fmt.Sprintf(lenfmt, name, i) - leases = append(leases, &LeaseInfo{ + lrange = append(lrange, &leases.Info{ HostName: hostName, Addr: cur, }) @@ -24,5 +28,18 @@ func enumerateRange(name string, r *Range, startFromOne bool) []*LeaseInfo { cur = cur.Next() } - return leases + return lrange +} + +func NetIPtoAddr(ip net.IP) netip.Addr { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.IPv4Unspecified() + } + + if !addr.Is4() { + return netip.IPv4Unspecified() + } + + return addr } diff --git a/kdhcpd.yaml b/kdhcpd.yaml index 857294e..f47c6e5 100644 --- a/kdhcpd.yaml +++ b/kdhcpd.yaml @@ -1,6 +1,6 @@ kdhcp: version: 1 - lease_file: /tmp/kdhcp_lease.yaml + lease_file: /tmp/kdhcp_lease.json interface: enp89s0 port: 67 network: diff --git a/leases/BUILD.bazel b/leases/BUILD.bazel new file mode 100644 index 0000000..09a609a --- /dev/null +++ b/leases/BUILD.bazel @@ -0,0 +1,19 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "leases", + srcs = [ + "hardware_address.go", + "lease_info.go", + "server_info.go", + ], + importpath = "git.wntrmute.dev/kyle/kdhcp/leases", + visibility = ["//visibility:public"], +) + +go_test( + name = "leases_test", + size = "small", + srcs = ["hardware_address_test.go"], + embed = [":leases"], +) diff --git a/iptools/hardware_address.go b/leases/hardware_address.go similarity index 97% rename from iptools/hardware_address.go rename to leases/hardware_address.go index f36a542..6a95e3e 100644 --- a/iptools/hardware_address.go +++ b/leases/hardware_address.go @@ -1,4 +1,4 @@ -package iptools +package leases import ( "bytes" diff --git a/iptools/hardware_address_test.go b/leases/hardware_address_test.go similarity index 96% rename from iptools/hardware_address_test.go rename to leases/hardware_address_test.go index 9d8f273..772aa29 100644 --- a/iptools/hardware_address_test.go +++ b/leases/hardware_address_test.go @@ -1,4 +1,4 @@ -package iptools +package leases import ( "bytes" diff --git a/iptools/lease_info.go b/leases/lease_info.go similarity index 63% rename from iptools/lease_info.go rename to leases/lease_info.go index ca97a9c..69b6f71 100644 --- a/iptools/lease_info.go +++ b/leases/lease_info.go @@ -1,4 +1,4 @@ -package iptools +package leases import ( "fmt" @@ -8,49 +8,49 @@ import ( "time" ) -type LeaseInfo struct { - HostName string `yaml:"hostname"` - Addr netip.Addr `yaml:"addr"` - HardwareAddress HardwareAddress `yaml:"mac_addr"` - Expires time.Time `yaml:"expires"` +type Info struct { + HostName string `yaml:"hostname" json:"hostname,omitempty"` + Addr netip.Addr `yaml:"addr" json:"addr,omitempty"` + HardwareAddress HardwareAddress `yaml:"mac_addr" json:"hardware_address,omitempty"` + Expires time.Time `yaml:"expires" json:"expires,omitempty"` } -func (li *LeaseInfo) String() string { - return fmt.Sprintf("lease[hostname=%s addr=%s hw=%x expires=%s]", li.HostName, li.Addr, li.HardwareAddress, li.Expires) +func (li *Info) String() string { + return fmt.Sprintf("lease[hostname=%s addr=%s hw=%s expires=%s]", li.HostName, li.Addr, li.HardwareAddress, li.Expires) } -type sortableLease []*LeaseInfo +type sortableLease []*Info func (a sortableLease) Len() int { return len(a) } func (a sortableLease) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a sortableLease) Less(i, j int) bool { return a[i].Addr.Less(a[j].Addr) } -func (li *LeaseInfo) ResetExpiry(t time.Time, dur time.Duration) { +func (li *Info) ResetExpiry(t time.Time, dur time.Duration) { li.Expires = t.Add(dur) } -func (li *LeaseInfo) IsExpired(t time.Time) bool { +func (li *Info) IsExpired(t time.Time) bool { return t.After(li.Expires) } -func (li *LeaseInfo) Expire() { +func (li *Info) Expire() { li.Expires = time.Time{} } -func SortLeases(leases []*LeaseInfo) []*LeaseInfo { +func Sort(leases []*Info) []*Info { sortable := sortableLease(leases) sort.Sort(sortable) - return []*LeaseInfo(sortable) + return []*Info(sortable) } -func (lease *LeaseInfo) Reset() *LeaseInfo { +func (lease *Info) Reset() *Info { lease.Expires = time.Time{} lease.HardwareAddress = nil return lease } -func (lease LeaseInfo) Compare(other LeaseInfo) error { +func (lease Info) Compare(other Info) error { susFields := []string{} if lease.Addr != other.Addr { diff --git a/leases/server_info.go b/leases/server_info.go new file mode 100644 index 0000000..a7b629b --- /dev/null +++ b/leases/server_info.go @@ -0,0 +1,17 @@ +package leases + +import "net/netip" + +type Server struct { + Addr netip.Addr + HardwareAddress HardwareAddress + Network *Network +} + +type Network struct { + Mask netip.Addr + Gateway netip.Addr + Broadcast netip.Addr + Domain string + DNS []netip.Addr +} diff --git a/server/BUILD.bazel b/server/BUILD.bazel index 0df45d9..2d14076 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "server", srcs = [ + "dhcp_transaction.go", "ifi.go", "ifi_linux.go", "lease_state.go", @@ -15,6 +16,7 @@ go_library( "//config", "//dhcp", "//iptools", + "//leases", "@com_github_benbjohnson_clock//:clock", "@dev_wntrmute_git_kyle_goutils//log", "@in_gopkg_yaml_v2//:yaml_v2", diff --git a/server/dhcp_transaction.go b/server/dhcp_transaction.go new file mode 100644 index 0000000..e282725 --- /dev/null +++ b/server/dhcp_transaction.go @@ -0,0 +1,58 @@ +package server + +import ( + "net" + + log "git.wntrmute.dev/kyle/goutils/log" + "git.wntrmute.dev/kyle/kdhcp/dhcp" +) + +func (srv *Server) WriteDHCPResponse(req *dhcp.Packet, ip net.Addr) error { + lease := srv.SelectLease(req, srv.clock.Now()) + if lease == nil { + log.Errln("server: couldn't find available lease") + return nil + } + + log.Debugf("available lease: %s", lease) + response, err := dhcp.NewOffer(req, srv.Config.Network.NetworkInfo(), lease) + if err != nil { + log.Errln(err.Error()) + return err + } + + log.Debugf("building offer for address %s", lease.Addr) + pkt, err := dhcp.WritePacket(response) + if err != nil { + log.Errln(err.Error()) + return err + } + + ip, err = addrReply(ip) + if err != nil { + log.Errln(err.Error()) + return err + } + + log.Debugf("writing offer to %s", ip) + _, err = srv.Conn.WriteTo(pkt, ip) + if err != nil { + log.Errf("failed to write packet: %s", err) + return err + } + + log.Infof("sending offer for address %s to %s", lease.Addr, req.HardwareAddress) + go func() { + if err := srv.SaveLeases(); err != nil { + log.Warningf("server: while saving leases: %s", err) + } + }() + + err = srv.AcceptLease(req, lease, srv.clock.Now()) + if err != nil { + log.Errln(err.Error()) + return err + } + + return nil +} diff --git a/server/ifi_linux.go b/server/ifi_linux.go index 6dd1b84..1c9a5b4 100644 --- a/server/ifi_linux.go +++ b/server/ifi_linux.go @@ -3,20 +3,20 @@ package server import ( + "fmt" "net" "syscall" + + log "git.wntrmute.dev/kyle/goutils/log" ) func BindInterface(ip net.IP, port int, dev string) (net.PacketConn, error) { - if port == 0 { - port = 67 - } - - udpAddr := &net.UDPAddr{ - IP: ip, - Port: port, + udpAddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf(":%d", port)) + if err != nil { + return nil, err } + log.Debugf("listen on %s", udpAddr) conn, err := net.ListenUDP("udp4", udpAddr) if err != nil { return nil, err diff --git a/server/lease_state.go b/server/lease_state.go index e3339da..1b8f391 100644 --- a/server/lease_state.go +++ b/server/lease_state.go @@ -1,62 +1,66 @@ package server import ( - "fmt" + "encoding/json" "os" log "git.wntrmute.dev/kyle/goutils/log" "git.wntrmute.dev/kyle/kdhcp/iptools" - "gopkg.in/yaml.v2" + "git.wntrmute.dev/kyle/kdhcp/leases" ) type LeaseState struct { - Pools map[string]*iptools.Pool `yaml:"pools"` - Static map[string]*iptools.LeaseInfo `yaml:"static"` + Pools map[string]*iptools.Pool `json:"pools"` + Static map[string]*leases.Info `json:"static"` } func (srv *Server) SaveLeases() error { - leaseFile, err := os.Create(srv.Config.LeaseFile) - if err != nil { - return fmt.Errorf("server: while saving leases: %w", err) - } - defer leaseFile.Close() - encoder := yaml.NewEncoder(leaseFile) - state := &LeaseState{ Pools: srv.Pools, Static: srv.Static, } - if err = encoder.Encode(state); err != nil { - return fmt.Errorf("server: while saving leases: %w", err) + out, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err } - if err = encoder.Close(); err != nil { - return fmt.Errorf("server: while saving leases: %w", err) + err = os.WriteFile(srv.Config.LeaseFile, out, 0644) + if err != nil { + return err } + log.Infof("wrote lease state to file %s", srv.Config.LeaseFile) return nil } func (srv *Server) LoadLeases() error { - leaseState := &LeaseState{} - leaseFile, err := os.Open(srv.Config.LeaseFile) + state := &LeaseState{} + + stateBytes, err := os.ReadFile(srv.Config.LeaseFile) if err != nil { if os.IsNotExist(err) { - log.Warningf("server: not loading leases from %s: lease file not found", srv.Config.LeaseFile) + log.Infof("not restoring leases from %s: file doesn't exist", srv.Config.LeaseFile) return nil } - - return fmt.Errorf("server: while reading leases: %w", err) - } - defer leaseFile.Close() - decoder := yaml.NewDecoder(leaseFile) - - if err = decoder.Decode(leaseState); err != nil { - return fmt.Errorf("server: while reading leases: %w", err) + return err } - srv.Pools = leaseState.Pools - srv.Static = leaseState.Static + err = json.Unmarshal(stateBytes, state) + if err != nil { + switch err := err.(type) { + case *json.UnmarshalTypeError: + log.Infof("type error: %q: [ \"bad type: got %s; want %s\" ]", err.Field, err.Value, err.Type.String()) + return err + case *json.InvalidUnmarshalError: + log.Infof("invalid unmarshal error: %s", err) + } + return err + } + + log.Infoln("restored lease states") + + srv.Pools = state.Pools + srv.Static = state.Static return nil } diff --git a/server/pools.go b/server/pools.go index ed4226d..d0ed088 100644 --- a/server/pools.go +++ b/server/pools.go @@ -7,6 +7,7 @@ import ( log "git.wntrmute.dev/kyle/goutils/log" "git.wntrmute.dev/kyle/kdhcp/iptools" + "git.wntrmute.dev/kyle/kdhcp/leases" ) // pools.go adds pool functionality to the server. @@ -18,7 +19,7 @@ func (srv *Server) loadPoolsFromConfig() error { return fmt.Errorf("server: while instantiating pools, could not load IP %s", ip) } log.Debugf("server: added static host entry %s -> %s", host, addr) - srv.Static[host] = &iptools.LeaseInfo{ + srv.Static[host] = &leases.Info{ HostName: host, Addr: addr, } diff --git a/server/server.go b/server/server.go index dc5e3b3..7ea9024 100644 --- a/server/server.go +++ b/server/server.go @@ -1,7 +1,7 @@ package server import ( - "errors" + "fmt" "net" "time" @@ -9,6 +9,7 @@ import ( "git.wntrmute.dev/kyle/kdhcp/config" "git.wntrmute.dev/kyle/kdhcp/dhcp" "git.wntrmute.dev/kyle/kdhcp/iptools" + "git.wntrmute.dev/kyle/kdhcp/leases" "github.com/benbjohnson/clock" ) @@ -18,11 +19,27 @@ const ( MaxResponseWait = 5 * time.Minute ) +func addrReply(addr net.Addr) (net.Addr, error) { + udpAddr, err := net.ResolveUDPAddr("udp4", addr.String()) + if err != nil { + return nil, err + } + + if udpAddr.IP.Equal(net.IPv4zero) { + return &net.UDPAddr{ + IP: net.IPv4bcast, + Port: udpAddr.Port, + }, nil + } + + return udpAddr, nil +} + type Server struct { Conn net.PacketConn Config *config.Config Pools map[string]*iptools.Pool - Static map[string]*iptools.LeaseInfo + Static map[string]*leases.Info clock clock.Clock } @@ -51,48 +68,74 @@ func (s *Server) ReadFrom() ([]byte, net.Addr, error) { return b, addr, nil } -func (s *Server) ReadDHCPRequest() (*dhcp.BootRequest, error) { +func (s *Server) ReadDHCPRequest() (*dhcp.Packet, net.Addr, error) { pkt, addr, err := s.ReadFrom() if err != nil { - return nil, err + return nil, nil, err } log.Debugf("server: read packet from %s", addr) - return dhcp.ReadRequest(pkt) + req, err := dhcp.ReadPacket(pkt) + if err != nil { + return nil, nil, err + } + + return req, addr, nil } func (s *Server) WriteTo(b []byte, addr net.Addr) error { - return errors.New("server: not implemented") + n, err := s.Conn.WriteTo(b, addr) + if err != nil { + return err + } + + log.Debugf("wrote %d bytes to %s", n, addr) + return err } -func (s *Server) AcceptPacket() (*dhcp.BootRequest, error) { - request, err := s.ReadDHCPRequest() +func (s *Server) AcceptPacket() (*dhcp.Packet, net.Addr, error) { + request, addr, err := s.ReadDHCPRequest() if err != nil { - return nil, err + return nil, nil, err } log.Debugf("BOOTP request received from %x", request.HardwareAddress) - return request, nil + return request, addr, nil } -func (s *Server) Listen() { - go s.updatePoolLoop() +func (srv *Server) Listen() { + go srv.updatePoolLoop() + go func() { + for { + time.Sleep(5 * time.Minute) + if err := srv.SaveLeases(); err != nil { + log.Warningf("server: while saving leases: %s", err) + } + } + }() for { - req, err := s.AcceptPacket() + req, addr, err := srv.AcceptPacket() if err != nil { log.Errf("server: error reading packet: %s", err) continue } - lease := s.SelectLease(req, s.clock.Now()) - if err != nil { - log.Err("server: couldn't find available lease") + if req.MessageType != dhcp.MessageTypeBootRequest { + log.Debugf("ignoring BOOTP message type %d", req.MessageType) continue } - log.Infof("available lease: %s", lease) - continue + switch req.DHCPType { + case dhcp.DHCPMessageTypeDiscover: + go srv.WriteDHCPResponse(req, addr) + case dhcp.DHCPMessageTypeRequest: + log.Debugf("not handling DHCP request") + case dhcp.DHCPMessageTypeRelease: + log.Debugf("not handling DHCP release") + default: + log.Debugf("not handling unknown request type %d", req.DHCPType) + } } } @@ -100,7 +143,7 @@ func New(cfg *config.Config) (*Server, error) { srv := &Server{ Config: cfg, Pools: map[string]*iptools.Pool{}, - Static: map[string]*iptools.LeaseInfo{}, + Static: map[string]*leases.Info{}, clock: clock.New(), } @@ -121,18 +164,36 @@ func New(cfg *config.Config) (*Server, error) { return srv, nil } -func (srv *Server) SelectLease(req *dhcp.BootRequest, t time.Time) *iptools.LeaseInfo { +func (srv *Server) SelectLease(req *dhcp.Packet, t time.Time) *leases.Info { if li, ok := srv.Static[req.HostName]; ok { return li } if pool, ok := srv.Pools[req.HostName]; ok { - return pool.Peek(t, MaxResponseWait) + return pool.Peek(req, t, MaxResponseWait) } if pool, ok := srv.Pools[DefaultPool]; ok { - return pool.Peek(t, MaxResponseWait) + return pool.Peek(req, t, MaxResponseWait) } return nil } + +func (srv *Server) AcceptLease(req *dhcp.Packet, li *leases.Info, t time.Time) error { + if _, ok := srv.Static[li.HostName]; ok { + return nil + } + + for _, pool := range srv.Pools { + if _, ok := pool.Active[li.Addr]; ok { + return nil + } + + if _, ok := pool.Limbo[li.Addr]; ok { + return pool.Accept(li.Addr, t) + } + } + + return fmt.Errorf("could not accept unknown lease: %s", li) +}