diff --git a/olm/olm.go b/olm/olm.go index ac28a7b..94098cb 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -428,22 +428,6 @@ func StartTunnel(config TunnelConfig) { } } - // Wrap TUN device with packet filter for DNS proxy - middleDev = middleDevice.NewMiddleDevice(tdev) - - // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU) - if err != nil { - logger.Error("Failed to create DNS proxy: %v", err) - } - if err := dnsProxy.Start(middleDev); err != nil { - logger.Error("Failed to start DNS proxy: %v", err) - } - ip := net.ParseIP("192.168.1.100") - if dnsProxy.AddDNSRecord("example.com", ip); err != nil { - logger.Error("Failed to add DNS record: %v", err) - } - // fileUAPI, err := func() (*os.File, error) { // if config.FileDescriptorUAPI != 0 { // fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) @@ -460,6 +444,9 @@ func StartTunnel(config TunnelConfig) { // return // } + // Wrap TUN device with packet filter for DNS proxy + middleDev = middleDevice.NewMiddleDevice(tdev) + wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") // Use filtered device instead of raw TUN device dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) @@ -486,10 +473,28 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to bring up WireGuard device: %v", err) } + // Create and start DNS proxy + dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU) + if err != nil { + logger.Error("Failed to create DNS proxy: %v", err) + } + if err := dnsProxy.Start(middleDev); err != nil { + logger.Error("Failed to start DNS proxy: %v", err) + } + ip := net.ParseIP("192.168.1.100") + if dnsProxy.AddDNSRecord("example.com", ip); err != nil { + logger.Error("Failed to add DNS record: %v", err) + } + if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil { logger.Error("Failed to configure interface: %v", err) } + if addRoutes([]string{"10.30.30.30/32"}, interfaceName); err != nil { + logger.Error("Failed to add route for DNS server: %v", err) + } + + // TODO: seperate adding the callback to this so we can init it above with the interface peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { // Find the site config to get endpoint information @@ -528,11 +533,11 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to configure peer: %v", err) return } - if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { + if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required logger.Error("Failed to add route for peer: %v", err) return } - if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { + if err := addRoutes(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) return } @@ -635,7 +640,7 @@ func StartTunnel(config TunnelConfig) { } // Add new remote subnet routes - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add new remote subnet routes: %v", err) return } @@ -688,7 +693,7 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to add route for new peer: %v", err) return } - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) return } @@ -814,7 +819,7 @@ func StartTunnel(config TunnelConfig) { } // Add routes for the new subnets - if err := addRoutesForRemoteSubnets(newSubnets, interfaceName); err != nil { + if err := addRoutes(newSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for new remote subnets: %v", err) return } @@ -927,10 +932,10 @@ func StartTunnel(config TunnelConfig) { // Then, add routes for new subnets if len(updateSubnetsData.NewRemoteSubnets) > 0 { - if err := addRoutesForRemoteSubnets(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil { + if err := addRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for new remote subnets: %v", err) // Attempt to rollback by re-adding old routes - if rollbackErr := addRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil { + if rollbackErr := addRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil { logger.Error("Failed to rollback old routes: %v", rollbackErr) } return diff --git a/olm/route.go b/olm/route.go index 439d929..14c18a1 100644 --- a/olm/route.go +++ b/olm/route.go @@ -10,6 +10,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/olm/network" + "github.com/vishvananda/netlink" ) func DarwinAddRoute(destination string, gateway string, interfaceName string) error { @@ -60,23 +61,40 @@ func LinuxAddRoute(destination string, gateway string, interfaceName string) err return nil } - var cmd *exec.Cmd + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Create route + route := &netlink.Route{ + Dst: ipNet, + } if gateway != "" { // Route with specific gateway - cmd = exec.Command("ip", "route", "add", destination, "via", gateway) + gw := net.ParseIP(gateway) + if gw == nil { + return fmt.Errorf("invalid gateway address: %s", gateway) + } + route.Gw = gw + logger.Info("Adding route to %s via gateway %s", destination, gateway) } else if interfaceName != "" { // Route via interface - cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName) + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + route.LinkIndex = link.Attrs().Index + logger.Info("Adding route to %s via interface %s", destination, interfaceName) } else { return fmt.Errorf("either gateway or interface must be specified") } - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ip route command failed: %v, output: %s", err, out) + // Add the route + if err := netlink.RouteAdd(route); err != nil { + return fmt.Errorf("failed to add route: %v", err) } return nil @@ -87,12 +105,22 @@ func LinuxRemoveRoute(destination string) error { return nil } - cmd := exec.Command("ip", "route", "del", destination) - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) if err != nil { - return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out) + return fmt.Errorf("invalid destination address: %v", err) + } + + // Create route to delete + route := &netlink.Route{ + Dst: ipNet, + } + + logger.Info("Removing route to %s", destination) + + // Delete the route + if err := netlink.RouteDel(route); err != nil { + return fmt.Errorf("failed to delete route: %v", err) } return nil @@ -268,8 +296,8 @@ func removeRouteForNetworkConfig(destination string) error { return nil } -// addRoutesForRemoteSubnets adds routes for each subnet in RemoteSubnets -func addRoutesForRemoteSubnets(remoteSubnets []string, interfaceName string) error { +// addRoutes adds routes for each subnet in RemoteSubnets +func addRoutes(remoteSubnets []string, interfaceName string) error { if len(remoteSubnets) == 0 { return nil }