diff --git a/cmd/kdhcpd/BUILD.bazel b/cmd/kdhcpd/BUILD.bazel index f300b9a..9ffd6e3 100644 --- a/cmd/kdhcpd/BUILD.bazel +++ b/cmd/kdhcpd/BUILD.bazel @@ -5,7 +5,10 @@ go_library( srcs = ["main.go"], importpath = "git.wntrmute.dev/kyle/kdhcp/cmd/kdhcpd", visibility = ["//visibility:private"], - deps = ["//server"], + deps = [ + "//log", + "//server", + ], ) go_binary( diff --git a/cmd/kdhcpd/main.go b/cmd/kdhcpd/main.go index 35bd285..80b1ed6 100644 --- a/cmd/kdhcpd/main.go +++ b/cmd/kdhcpd/main.go @@ -2,18 +2,33 @@ package main import ( "flag" - "log" + "git.wntrmute.dev/kyle/kdhcp/log" "git.wntrmute.dev/kyle/kdhcp/server" ) func main() { - cfg := &server.Config{} - flag.StringVar(&cfg.Device, "i", "eth0", "network `interface` to listen on") + cfg := server.DefaultConfig() + var level, tag string + flag.StringVar(&level, "l", "DEBUG", "log level") // TODO(kyle): change this warning later + flag.IntVar(&cfg.Port, "p", cfg.Port, "port to listen on") + flag.StringVar(&tag, "t", "kdhcpd", "logging tag") flag.Parse() - _, err := server.NewServer(cfg) + log.Setup(level, tag) + + srv, err := server.NewServer(cfg) if err != nil { log.Fatal(err) } + + for { + packet, err := srv.ReadFrom() + if err != nil { + log.Warning(err) + continue + } + + log.Debugf("receive %d byte packet from %s", len(packet.Data), packet.Addr) + } } diff --git a/log/BUILD.bazel b/log/BUILD.bazel index a73a852..c9fceed 100644 --- a/log/BUILD.bazel +++ b/log/BUILD.bazel @@ -5,5 +5,5 @@ go_library( srcs = ["logger.go"], importpath = "git.wntrmute.dev/kyle/kdhcp/log", visibility = ["//visibility:public"], - deps = ["@com_github_hashicorp_go_syslog//:go-syslog"], + deps = ["//bazel-kdhcp/external/com_github_hashicorp_go_syslog:go-syslog"], ) diff --git a/log/logger.go b/log/logger.go index a7761e5..25a600b 100644 --- a/log/logger.go +++ b/log/logger.go @@ -2,6 +2,8 @@ package log import ( "fmt" + "os" + "strings" "time" gsyslog "github.com/hashicorp/go-syslog" @@ -13,6 +15,10 @@ type logger struct { } func (log *logger) printf(p gsyslog.Priority, format string, args ...interface{}) { + if !strings.HasSuffix(format, "\n") { + format += "\n" + } + if p <= log.p { fmt.Printf("%s [%s] ", prioritiev[p], timestamp()) fmt.Printf(format, args...) @@ -158,34 +164,45 @@ func Emergln(args ...interface{}) { log.println(gsyslog.LOG_EMERG, args...) } -func Debugf(args ...interface{}) { - log.printf(gsyslog.LOG_DEBUG, args...) +func Debugf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_DEBUG, format, args...) } -func Infof(args ...interface{}) { - log.printf(gsyslog.LOG_INFO, args...) +func Infof(format string, args ...interface{}) { + log.printf(gsyslog.LOG_INFO, format, args...) } -func Noticef(args ...interface{}) { - log.printf(gsyslog.LOG_NOTICE, args...) +func Noticef(format string, args ...interface{}) { + log.printf(gsyslog.LOG_NOTICE, format, args...) } -func Warningf(args ...interface{}) { - log.print(gsyslog.LOG_WARNING, args...) +func Warningf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_WARNING, format, args...) } -func Errf(args ...interface{}) { - log.printf(gsyslog.LOG_ERR, args...) +func Errf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_ERR, format, args...) } -func Critf(args ...interface{}) { - log.printf(gsyslog.LOG_CRIT, args...) +func Critf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_CRIT, format, args...) } -func Alertf(args ...interface{}) { - log.printf(gsyslog.LOG_ALERT, args...) +func Alertf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_ALERT, format, args...) } -func Emergf(args ...interface{}) { - log.printf(gsyslog.LOG_EMERG, args...) +func Emergf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_EMERG, format, args...) + os.Exit(1) +} + +func Fatal(args ...interface{}) { + log.println(gsyslog.LOG_ERR, args...) + os.Exit(1) +} + +func Fatalf(format string, args ...interface{}) { + log.printf(gsyslog.LOG_ERR, format, args...) + os.Exit(1) } diff --git a/server/BUILD.bazel b/server/BUILD.bazel index 78cc2a4..f69a764 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -2,8 +2,16 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "server", - srcs = ["server.go"], + srcs = [ + "addr.go", + "config.go", + "packet.go", + "server.go", + ], importpath = "git.wntrmute.dev/kyle/kdhcp/server", visibility = ["//visibility:public"], - deps = ["@com_github_davecgh_go_spew//spew"], + deps = [ + "//bazel-kdhcp/external/com_github_davecgh_go_spew/spew", + "//log", + ], ) diff --git a/server/addr.go b/server/addr.go new file mode 100644 index 0000000..c623f4c --- /dev/null +++ b/server/addr.go @@ -0,0 +1,61 @@ +package server + +import ( + "fmt" + "net" + + "git.wntrmute.dev/kyle/kdhcp/bazel-kdhcp/external/com_github_davecgh_go_spew/spew" + "git.wntrmute.dev/kyle/kdhcp/log" +) + +type addr struct { + ifi *net.Interface + IP net.IP + ipn *net.IPNet +} + +func addrsForDevice(dev string) ([]addr, error) { + netInterface, err := net.InterfaceByName(dev) + if err != nil { + return nil, fmt.Errorf("while selecting interface %s: %w", dev, err) + } + + spew.Dump(netInterface) + + var addrs []addr + devAddrs, err := netInterface.Addrs() + if err != nil { + return nil, err + } + + for _, devAddr := range devAddrs { + log.Debugf("consider %s", devAddr.String()) + ip, ipn, err := net.ParseCIDR(devAddr.String()) + if err != nil { + continue + } + + if ip == nil { + continue // address isn't an IP address + } + + log.Debugf("found IP: %s", ip) + ip = ip.To4() + + // DHCP should only listen on private addresses. + if !ip.IsPrivate() { + log.Debugln("skipping non-private") + continue + } + + // only support IPv4 for now + if len(ip) != 4 { + log.Debugf("%d IP, only supporting v4 right now", len(ip)) + continue + } + + addrs = append(addrs, addr{netInterface, ip, ipn}) + } + + return addrs, nil +} diff --git a/server/config.go b/server/config.go new file mode 100644 index 0000000..782f532 --- /dev/null +++ b/server/config.go @@ -0,0 +1,19 @@ +package server + +const ( + DefaultPort = 67 + DefaultMaxPacketSize = 512 + DefaultNetwork = "udp4" +) + +type Config struct { + Port int + MaxPacketSize int +} + +func DefaultConfig() *Config { + return &Config{ + Port: DefaultPort, + MaxPacketSize: DefaultMaxPacketSize, + } +} diff --git a/server/packet.go b/server/packet.go new file mode 100644 index 0000000..20f2604 --- /dev/null +++ b/server/packet.go @@ -0,0 +1,8 @@ +package server + +import "net" + +type Packet struct { + Data []byte + Addr net.Addr +} diff --git a/server/server.go b/server/server.go index 87bddf8..2e259ea 100644 --- a/server/server.go +++ b/server/server.go @@ -1,82 +1,62 @@ package server - import ( - "errors" - "log" + "fmt" "net" - "github.com/davecgh/go-spew/spew" + "git.wntrmute.dev/kyle/kdhcp/log" ) // github.com/insomniacslk/dhcp -type addr struct { - IP net.IP - ipn *net.IPNet -} - -type Config struct { - Device string `yaml:"device"` -} - type Server struct { - addrs []net.IP - l net.Listener + cfg *Config + conn net.PacketConn } -func addrsForDevice(dev string) ([]addr, error) { - netInterface, err := net.InterfaceByName(dev) +func (srv *Server) Listen() (err error) { + if srv.conn != nil { + srv.conn.Close() + } + + log.Debugf("attempting to set up packet listener on %s 0.0.0.0:%d", DefaultNetwork, srv.cfg.Port) + srv.conn, err = net.ListenPacket(DefaultNetwork, fmt.Sprintf(":%d", srv.cfg.Port)) + if err != nil { + return err + } + + return nil +} + +func (srv *Server) Close() error { + if srv.conn != nil { + return srv.conn.Close() + } + return nil +} + +func (srv *Server) ReadFrom() (*Packet, error) { + b := make([]byte, srv.cfg.MaxPacketSize) + n, addr, err := srv.conn.ReadFrom(b) if err != nil { return nil, err } - spew.Dump(netInterface) - - var addrs []addr - devAddrs, err := netInterface.Addrs() - if err != nil { - return nil, err - } - - for _, devAddr := range devAddrs { - log.Printf("consider %s", devAddr.String()) - ip, ipn, err := net.ParseCIDR(devAddr.String()) - if err != nil { - continue - } - - if ip == nil { - continue // address isn't an IP address - } - - log.Printf("found IP: %s", ip) - ip = ip.To4() - - // DHCP should only listen on private addresses. - if !ip.IsPrivate() { - log.Println("skipping non-private") - continue - } - - // only support IPv4 for now - if len(ip) != 4 { - log.Printf("%d IP, only supporting v4 right now", len(ip)) - continue - } - - addrs = append(addrs, addr{ip, ipn}) - } - - return addrs, nil + return &Packet{ + Data: b[:n], + Addr: addr, + }, nil } func NewServer(cfg *Config) (*Server, error) { - addrs, err := addrsForDevice(cfg.Device) + srv := &Server{ + cfg: cfg, + } + + err := srv.Listen() if err != nil { return nil, err } - log.Println("server IP list: ", addrs) - return nil, errors.New("not implemented") + return srv, nil }