iptools: reorganizing and continuing pool work

This commit is contained in:
Kyle Isom 2023-05-03 02:57:04 +00:00
parent 77c21c9747
commit 842181f555
14 changed files with 360 additions and 31 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ bazel-testlogs
kdhcpd kdhcpd
cmd/kdhcpd/kdhcpd cmd/kdhcpd/kdhcpd
strace.txt

View File

@ -9,6 +9,7 @@ go_library(
importpath = "git.wntrmute.dev/kyle/kdhcp/config", importpath = "git.wntrmute.dev/kyle/kdhcp/config",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//iptools",
"//log", "//log",
"@in_gopkg_yaml_v2//:yaml_v2", "@in_gopkg_yaml_v2//:yaml_v2",
], ],

View File

@ -6,6 +6,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"git.wntrmute.dev/kyle/kdhcp/iptools"
"git.wntrmute.dev/kyle/kdhcp/log" "git.wntrmute.dev/kyle/kdhcp/log"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -23,23 +24,6 @@ func ensureV4(ip net.IP) (net.IP, error) {
return ip4, nil return ip4, nil
} }
type IPRange struct {
Start net.IP
End net.IP
}
func (r *IPRange) ensureV4() (err error) {
if r.Start, err = ensureV4(r.Start); err != nil {
return fmt.Errorf("config: range start address %w", err)
}
if r.End, err = ensureV4(r.End); err != nil {
return fmt.Errorf("config: range end address %w", err)
}
return nil
}
type Network struct { type Network struct {
IP net.IP `yaml:"address"` IP net.IP `yaml:"address"`
Gateway net.IP `yaml:"gateway"` Gateway net.IP `yaml:"gateway"`
@ -85,14 +69,14 @@ type ConfigFile struct {
} }
type Config struct { type Config struct {
Version int `yaml:"version"` Version int `yaml:"version"`
Interface string `yaml:"interface"` Interface string `yaml:"interface"`
Address string `yaml:"address"` Address string `yaml:"address"`
Port int `yaml:"port"` Port int `yaml:"port"`
LeaseFile string `yaml:"lease_file"` LeaseFile string `yaml:"lease_file"`
Network *Network `yaml:"network"` Network *Network `yaml:"network"`
Pools map[string]*IPRange `yaml:"pools"` Pools map[string]*iptools.Range `yaml:"pools"`
Statics map[string]net.IP `yaml:"statics"` Statics map[string]net.IP `yaml:"statics"`
} }
func (cfg *Config) process() (err error) { func (cfg *Config) process() (err error) {
@ -114,7 +98,7 @@ func (cfg *Config) process() (err error) {
} }
for k, v := range cfg.Pools { for k, v := range cfg.Pools {
if err = v.ensureV4(); err != nil { if err = v.Validate(); err != nil {
return fmt.Errorf("config: pool %s %w", k, err) return fmt.Errorf("config: pool %s %w", k, err)
} }

View File

@ -2,6 +2,7 @@ package dhcp
import ( import (
"errors" "errors"
"fmt"
"io" "io"
) )
@ -11,14 +12,14 @@ type Option func(req *BootRequest, r io.Reader) error
const ( const (
OptionTagPadding OptionTag = 0 OptionTagPadding OptionTag = 0
OptionTagHostName = 12 OptionTagHostName OptionTag = 12
OptionTagMessageType = 53 OptionTagMessageType OptionTag = 53
OptionTagParameterRequestList = 55 OptionTagParameterRequestList OptionTag = 55
OptionTagEnd = 255 OptionTagEnd OptionTag = 255
) )
var optionRegistry = map[OptionTag]Option{ var optionRegistry = map[OptionTag]Option{
OptionTagPadding: OptionTag, OptionTagPadding: OptionPad,
OptionTagHostName: OptionHostName, OptionTagHostName: OptionHostName,
OptionTagMessageType: OptionMessageType, OptionTagMessageType: OptionMessageType,
OptionTagParameterRequestList: OptionParameterRequestList, OptionTagParameterRequestList: OptionParameterRequestList,
@ -44,3 +45,13 @@ func OptionParameterRequestList(req *BootRequest, r io.Reader) error {
func OptionEnd(req *BootRequest, r io.Reader) error { func OptionEnd(req *BootRequest, r io.Reader) error {
return errors.New("dhcp: option not implemented yet") return errors.New("dhcp: option not implemented yet")
} }
func ReadOption(req *BootRequest, tag byte, r io.Reader) error {
opt := OptionTag(tag)
if f, ok := optionRegistry[opt]; ok {
return f(req, r)
}
return fmt.Errorf("dhcp: unknown/unhandled option %d", opt)
}

17
iptools/BUILD.bazel Normal file
View File

@ -0,0 +1,17 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "iptools",
srcs = [
"pool.go",
"range.go",
],
importpath = "git.wntrmute.dev/kyle/kdhcp/iptools",
visibility = ["//visibility:public"],
)
go_test(
name = "iptools_test",
srcs = ["range_test.go"],
embed = [":iptools"],
)

View File

@ -0,0 +1,41 @@
package iptools
import (
"bytes"
"fmt"
"strconv"
"strings"
)
type HardwareAddress []byte
func (mac HardwareAddress) String() string {
marshalled := []string{}
for i := 0; i < len(mac); i++ {
marshalled = append(marshalled, fmt.Sprintf("%02x", []byte(mac[i:i+1])))
}
return strings.Join(marshalled, ":")
}
func (mac HardwareAddress) MarshalText() ([]byte, error) {
return []byte(mac.String()), nil
}
func (mac *HardwareAddress) UnmarshalText(b []byte) error {
rb := bytes.Split(b, []byte(":"))
for _, octet := range rb {
n, err := strconv.ParseUint(string(octet), 16, 8)
if err != nil {
return err
}
*mac = append(*mac, uint8(n))
}
return nil
}
func (mac HardwareAddress) Match(other HardwareAddress) bool {
return bytes.Equal(mac, other)
}

View File

@ -0,0 +1,31 @@
package iptools
import (
"bytes"
"testing"
)
func TestHardwareMacMarshalling(t *testing.T) {
macString := "b8:27:eb:b6:a1:a7"
mac := HardwareAddress([]byte{0xb8, 0x27, 0xeb, 0xb6, 0xa1, 0xa7})
b, err := mac.MarshalText()
if err != nil {
t.Fatal(err)
}
s := string(b)
if s != macString {
t.Fatalf("have %s, want %s", s, macString)
}
mac2 := &HardwareAddress{}
err = mac2.UnmarshalText(b)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(*mac2, mac) {
t.Fatalf("have %x, want %x", *mac2, mac)
}
}

69
iptools/lease_info.go Normal file
View File

@ -0,0 +1,69 @@
package iptools
import (
"fmt"
"net/netip"
"sort"
"strings"
"time"
)
type LeaseInfo struct {
HostName string `yaml:"hostname"`
Addr netip.Addr `yaml:"addr"`
HardwareAddress HardwareAddress `yaml:"mac_addr"`
Expires time.Time `yaml:"expires"`
}
type sortableLease []LeaseInfo
func (a sortableLease) Len() int { return len(a) }
func (a sortableLease) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a sortableLease) Less(i, j int) bool { return a[i].Addr.Less(a[j].Addr) }
func (li *LeaseInfo) ResetExpiry(t time.Time, dur time.Duration) {
li.Expires = t.Add(dur)
}
func (li *LeaseInfo) IsExpired(t time.Time) bool {
return t.After(li.Expires)
}
func (li *LeaseInfo) Expire() {
li.Expires = time.Time{}
}
func SortLeases(leases []LeaseInfo) []LeaseInfo {
sortable := sortableLease(leases)
sort.Sort(sortable)
return []LeaseInfo(sortable)
}
func (lease LeaseInfo) Reset() LeaseInfo {
lease.Expires = time.Time{}
lease.HardwareAddress = nil
return lease
}
func (lease LeaseInfo) Compare(other LeaseInfo) error {
susFields := []string{}
if lease.Addr != other.Addr {
susFields = append(susFields, fmt.Sprintf("address is %s, but is recorded as %s", lease.Addr, other.Addr))
}
if !lease.HardwareAddress.Match(other.HardwareAddress) {
susFields = append(susFields, fmt.Sprintf("hardware address is %s, but is recorded as %s", lease.HardwareAddress, other.HardwareAddress))
}
if lease.HostName != other.HostName {
susFields = append(susFields, fmt.Sprintf("hostname is %s, but is recorded as %s", lease.HostName, other.HostName))
}
if len(susFields) > 0 {
return fmt.Errorf("suspicious lease: %s", strings.Join(susFields, ";"))
}
return nil
}

37
iptools/pool_test.go Normal file
View File

@ -0,0 +1,37 @@
package iptools
import (
"fmt"
"net/netip"
"testing"
"time"
)
var (
poolTestIP1 = netip.MustParseAddr("192.168.4.1")
poolTestIP2 = netip.MustParseAddr("192.168.4.32")
)
func TestBasicPool(t *testing.T) {
r := &Range{
Start: poolTestIP1,
End: poolTestIP2,
}
p, err := NewPool("cluster", 24*time.Hour, r)
if err != nil {
t.Fatal(err)
}
if len(p.Available) != 32 {
t.Fatalf("have %d available leases, want %d", len(p.Available), 32)
}
for i := range p.Available {
l := p.Available[i]
expectedName := fmt.Sprintf("cluster%02d", l.Addr.As4()[3])
if l.HostName != expectedName {
t.Fatalf("have hostname %s, want %s", l.HostName, expectedName)
}
}
}

61
iptools/range.go Normal file
View File

@ -0,0 +1,61 @@
package iptools
import (
"fmt"
"net/netip"
)
const (
DefaultMaskBits = 24
)
type Range struct {
Start netip.Addr `yaml:"start"`
End netip.Addr `yaml:"end"`
Network netip.Prefix `yaml:"network"`
}
func (r *Range) Validate() error {
if !r.Start.Is4() {
return fmt.Errorf("range start %s is not a valid IPv4 address", r.Start)
}
if !r.End.Is4() {
return fmt.Errorf("range end %s is not a valid IPv4 address", r.End)
}
// Compare returns -1 if lhs < rhs, 0 if lhs == rhs, and 1 if lhs > rhs.
if r.End.Compare(r.Start) != 1 {
return fmt.Errorf("start address %s is not before end address %s", r.Start, r.End)
}
var err error
if !r.Network.IsValid() {
r.Network, err = r.Start.Prefix(DefaultMaskBits)
if err != nil {
return err
}
}
if !r.Network.Contains(r.Start) {
return fmt.Errorf("prefix %s does not contain start address %s", r.Network, r.Start)
}
if !r.Network.Contains(r.End) {
return fmt.Errorf("prefix %s does not contain end address %s", r.Network, r.End)
}
return nil
}
// this is probably dumb, but it's a one-time cost upfront on pool instantiation.
func (r *Range) numHosts() int {
cur := r.Start
hosts := 0
for cur.Compare(r.End) < 1 {
hosts++
cur.Next()
}
return hosts
}

49
iptools/range_test.go Normal file
View File

@ -0,0 +1,49 @@
package iptools
import (
"net/netip"
"testing"
)
var (
rangeTestIP1 = netip.AddrFrom4([4]byte{192, 168, 4, 3})
rangeTestIP2 = netip.AddrFrom4([4]byte{192, 168, 4, 17})
)
func TestBasicValidation(t *testing.T) {
r1 := &Range{
Start: rangeTestIP1,
End: rangeTestIP2,
}
if err := r1.Validate(); err != nil {
t.Fatalf("range 1 should be valid: %s", err)
}
r2 := &Range{
Start: rangeTestIP2,
End: rangeTestIP1,
}
if r2.Validate() == nil {
t.Fatal("range 2 should be invalid")
}
r3 := &Range{
Start: netip.IPv6LinkLocalAllRouters(),
End: rangeTestIP1,
}
if r3.Validate() == nil {
t.Fatal("range 3 should be invalid")
}
r4 := &Range{
Start: rangeTestIP2,
End: netip.IPv6LinkLocalAllRouters(),
}
if r4.Validate() == nil {
t.Fatal("range 4 should be invalid")
}
}

25
iptools/tools.go Normal file
View File

@ -0,0 +1,25 @@
package iptools
import "fmt"
func enumerateRange(name string, r *Range, startFromOne bool) []LeaseInfo {
start := r.Start
cur := start
lenfmt := fmt.Sprintf("%%s%%0%dd", len(fmt.Sprintf("%d", r.numHosts())))
i := 0
if startFromOne {
i++
}
leases := []LeaseInfo{}
for r.End.Compare(cur) >= 0 {
leases = append(leases, LeaseInfo{
HostName: fmt.Sprintf(lenfmt, name, i),
Addr: cur,
})
i++
cur = cur.Next()
}
return leases
}

View File

@ -4,6 +4,7 @@ go_library(
name = "server", name = "server",
srcs = [ srcs = [
"ifi.go", "ifi.go",
"ifi_linux.go",
"server.go", "server.go",
], ],
importpath = "git.wntrmute.dev/kyle/kdhcp/server", importpath = "git.wntrmute.dev/kyle/kdhcp/server",

View File

@ -70,6 +70,7 @@ func (s *Server) Listen() {
log.Errf("server: error reading packet: %s", err) log.Errf("server: error reading packet: %s", err)
continue continue
} }
break
} }
} }