mirror of
https://github.com/fosrl/olm.git
synced 2026-03-06 10:46:42 +00:00
51
olm/olm.go
51
olm/olm.go
@@ -428,22 +428,6 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap TUN device with packet filter for DNS proxy
|
|
||||||
middleDev = middleDevice.NewMiddleDevice(tdev)
|
|
||||||
|
|
||||||
// Create and start DNS proxy
|
|
||||||
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to create DNS proxy: %v", err)
|
|
||||||
}
|
|
||||||
if err := dnsProxy.Start(middleDev); err != nil {
|
|
||||||
logger.Error("Failed to start DNS proxy: %v", err)
|
|
||||||
}
|
|
||||||
ip := net.ParseIP("192.168.1.100")
|
|
||||||
if dnsProxy.AddDNSRecord("example.com", ip); err != nil {
|
|
||||||
logger.Error("Failed to add DNS record: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// fileUAPI, err := func() (*os.File, error) {
|
// fileUAPI, err := func() (*os.File, error) {
|
||||||
// if config.FileDescriptorUAPI != 0 {
|
// if config.FileDescriptorUAPI != 0 {
|
||||||
// fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32)
|
// fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32)
|
||||||
@@ -460,6 +444,9 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
// Wrap TUN device with packet filter for DNS proxy
|
||||||
|
middleDev = middleDevice.NewMiddleDevice(tdev)
|
||||||
|
|
||||||
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||||
// Use filtered device instead of raw TUN device
|
// Use filtered device instead of raw TUN device
|
||||||
dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger))
|
dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger))
|
||||||
@@ -486,10 +473,28 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create and start DNS proxy
|
||||||
|
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create DNS proxy: %v", err)
|
||||||
|
}
|
||||||
|
if err := dnsProxy.Start(middleDev); err != nil {
|
||||||
|
logger.Error("Failed to start DNS proxy: %v", err)
|
||||||
|
}
|
||||||
|
ip := net.ParseIP("192.168.1.100")
|
||||||
|
if dnsProxy.AddDNSRecord("example.com", ip); err != nil {
|
||||||
|
logger.Error("Failed to add DNS record: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if addRoutes([]string{"10.30.30.30/32"}, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for DNS server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: seperate adding the callback to this so we can init it above with the interface
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
// Find the site config to get endpoint information
|
// Find the site config to get endpoint information
|
||||||
@@ -528,11 +533,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to configure peer: %v", err)
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
|
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required
|
||||||
logger.Error("Failed to add route for peer: %v", err)
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
if err := addRoutes(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -635,7 +640,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add new remote subnet routes
|
// Add new remote subnet routes
|
||||||
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add new remote subnet routes: %v", err)
|
logger.Error("Failed to add new remote subnet routes: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -688,7 +693,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to add route for new peer: %v", err)
|
logger.Error("Failed to add route for new peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -814,7 +819,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add routes for the new subnets
|
// Add routes for the new subnets
|
||||||
if err := addRoutesForRemoteSubnets(newSubnets, interfaceName); err != nil {
|
if err := addRoutes(newSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -927,10 +932,10 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
// Then, add routes for new subnets
|
// Then, add routes for new subnets
|
||||||
if len(updateSubnetsData.NewRemoteSubnets) > 0 {
|
if len(updateSubnetsData.NewRemoteSubnets) > 0 {
|
||||||
if err := addRoutesForRemoteSubnets(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil {
|
if err := addRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
||||||
// Attempt to rollback by re-adding old routes
|
// Attempt to rollback by re-adding old routes
|
||||||
if rollbackErr := addRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil {
|
if rollbackErr := addRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil {
|
||||||
logger.Error("Failed to rollback old routes: %v", rollbackErr)
|
logger.Error("Failed to rollback old routes: %v", rollbackErr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
58
olm/route.go
58
olm/route.go
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/network"
|
"github.com/fosrl/olm/network"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
@@ -60,23 +61,40 @@ func LinuxAddRoute(destination string, gateway string, interfaceName string) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var cmd *exec.Cmd
|
// Parse destination CIDR
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create route
|
||||||
|
route := &netlink.Route{
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
if gateway != "" {
|
if gateway != "" {
|
||||||
// Route with specific gateway
|
// Route with specific gateway
|
||||||
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
gw := net.ParseIP(gateway)
|
||||||
|
if gw == nil {
|
||||||
|
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||||
|
}
|
||||||
|
route.Gw = gw
|
||||||
|
logger.Info("Adding route to %s via gateway %s", destination, gateway)
|
||||||
} else if interfaceName != "" {
|
} else if interfaceName != "" {
|
||||||
// Route via interface
|
// Route via interface
|
||||||
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
route.LinkIndex = link.Attrs().Index
|
||||||
|
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("either gateway or interface must be specified")
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
// Add the route
|
||||||
|
if err := netlink.RouteAdd(route); err != nil {
|
||||||
out, err := cmd.CombinedOutput()
|
return fmt.Errorf("failed to add route: %v", err)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -87,12 +105,22 @@ func LinuxRemoveRoute(destination string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("ip", "route", "del", destination)
|
// Parse destination CIDR
|
||||||
logger.Info("Running command: %v", cmd)
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create route to delete
|
||||||
|
route := &netlink.Route{
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removing route to %s", destination)
|
||||||
|
|
||||||
|
// Delete the route
|
||||||
|
if err := netlink.RouteDel(route); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete route: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -268,8 +296,8 @@ func removeRouteForNetworkConfig(destination string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addRoutesForRemoteSubnets adds routes for each subnet in RemoteSubnets
|
// addRoutes adds routes for each subnet in RemoteSubnets
|
||||||
func addRoutesForRemoteSubnets(remoteSubnets []string, interfaceName string) error {
|
func addRoutes(remoteSubnets []string, interfaceName string) error {
|
||||||
if len(remoteSubnets) == 0 {
|
if len(remoteSubnets) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user