package config import ( "errors" "fmt" "io/ioutil" "net" "strconv" "git.wntrmute.dev/kyle/kdhcp/log" "gopkg.in/yaml.v2" ) const ( CurrentVersion = 1 ) func ensureV4(ip net.IP) (net.IP, error) { ip4 := ip.To4() if ip4 == nil { return ip4, fmt.Errorf("%s isn't an IPv4 address", ip) } return ip4, nil } type IPRange struct { Start net.IP End net.IP } func (r *IPRange) ensureV4() (err error) { if r.Start, err = ensureV4(r.Start); err != nil { return fmt.Errorf("config: range start address %w", err) } if r.End, err = ensureV4(r.End); err != nil { return fmt.Errorf("config: range end address %w", err) } return nil } type Network struct { Gateway net.IP `yaml:"gateway"` Mask net.IP `yaml:"mask"` Broadcast net.IP `yaml:"broadcast"` DNS []net.IP `yaml:"dns"` Domain string `yaml:"domain"` } func (n *Network) ensureV4() (err error) { n.Gateway, err = ensureV4(n.Gateway) if err != nil { return fmt.Errorf("config: gateway %w", err) } n.Mask, err = ensureV4(n.Mask) if err != nil { return fmt.Errorf("config: mask %w", err) } n.Broadcast, err = ensureV4(n.Broadcast) if err != nil { return fmt.Errorf("config: broadcast %w", err) } for i := range n.DNS { n.DNS[i], err = ensureV4(n.DNS[i]) if err != nil { return fmt.Errorf("config: DNS address %w", err) } } return nil } type ConfigFile struct { Server *Config `yaml:"kdhcp"` } type Config struct { Version int `yaml:"version"` Interface string `yaml:"interface"` Address string `yaml:"address"` IP net.IP Port int LeaseFile string `yaml:"lease_file"` Network *Network `yaml:"network"` Pools map[string]*IPRange `yaml:"pools"` Statics map[string]net.IP `yaml:"statics"` } func (cfg *Config) process() (err error) { switch { case cfg.Version == 0: log.Warningln("config: Version is 0, which indicates it hasn't been set. The config may be invalid.") case cfg.Version > CurrentVersion: log.Warningf("config: Version is greater than the current version %d. The config may not behave as expected.", CurrentVersion) } _, err = net.InterfaceByName(cfg.Interface) if err != nil { return fmt.Errorf("config: while looking up interface %s: %w", cfg.Interface, err) } ip, port, err := net.SplitHostPort(cfg.Address) if err != nil { return err } cfg.IP = net.ParseIP(ip) if cfg.IP == nil { return fmt.Errorf("config: parsing IP from address %s: %w", cfg.Address, err) } cfg.IP, err = ensureV4(cfg.IP) if err != nil { return fmt.Errorf("config: address %w", err) } cfg.Port, err = strconv.Atoi(port) if err != nil { return fmt.Errorf("config: invalid port %s: %w", port, err) } err = cfg.Network.ensureV4() if err != nil { return err } for k, v := range cfg.Pools { if err = v.ensureV4(); err != nil { return fmt.Errorf("config: pool %s %w", k, err) } cfg.Pools[k] = v } for k, v := range cfg.Statics { cfg.Statics[k], err = ensureV4(v) if err != nil { return fmt.Errorf("config: %s %w", k, err) } } return nil } func Load(path string) (*Config, error) { if path == "" { path = FindConfigPath() } if path == "" { return nil, errors.New("config: no config file path specified and couldn't find a valid config file path") } data, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("config: loading %s: %w", path, err) } configFile := &ConfigFile{} err = yaml.Unmarshal(data, configFile) if err != nil { return nil, fmt.Errorf("config: while unmarshaling %s: %w", path, err) } if configFile.Server == nil { log.Fatal("missing `kdhcp` section of config") } config := configFile.Server if err = config.process(); err != nil { return nil, err } log.Debugf("config: read configuration from %s", path) return config, nil }