From 20b3331ffff7fec5a0149703599cb04c21a8d945 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 24 Nov 2025 17:04:33 -0500 Subject: [PATCH] Refactor modules --- DNS_PROXY_README.md | 186 -------- IMPLEMENTATION_SUMMARY.md | 214 ---------- api/api.go | 33 +- olm/unix.go => device/tun_unix.go | 8 +- olm/windows.go => device/tun_windows.go | 8 +- diff | 523 ----------------------- {olm => network}/interface.go | 16 +- {olm => network}/interface_notwindows.go | 2 +- {olm => network}/interface_windows.go | 2 +- {olm => network}/route.go | 27 +- {olm => network}/route_notwindows.go | 2 +- {olm => network}/route_windows.go | 2 +- network/{network.go => settings.go} | 6 + olm/olm.go | 42 +- olm/{common.go => util.go} | 0 15 files changed, 71 insertions(+), 1000 deletions(-) delete mode 100644 DNS_PROXY_README.md delete mode 100644 IMPLEMENTATION_SUMMARY.md rename olm/unix.go => device/tun_unix.go (77%) rename olm/windows.go => device/tun_windows.go (62%) delete mode 100644 diff rename {olm => network}/interface.go (91%) rename {olm => network}/interface_notwindows.go (92%) rename {olm => network}/interface_windows.go (99%) rename {olm => network}/route.go (88%) rename {olm => network}/route_notwindows.go (92%) rename {olm => network}/route_windows.go (99%) rename network/{network.go => settings.go} (97%) rename olm/{common.go => util.go} (100%) diff --git a/DNS_PROXY_README.md b/DNS_PROXY_README.md deleted file mode 100644 index 272ccd8..0000000 --- a/DNS_PROXY_README.md +++ /dev/null @@ -1,186 +0,0 @@ -# Virtual DNS Proxy Implementation - -## Overview - -This implementation adds a high-performance virtual DNS proxy that intercepts DNS queries destined for `10.30.30.30:53` before they reach the WireGuard tunnel. The proxy processes DNS queries using a gvisor netstack and forwards them to upstream DNS servers, bypassing the VPN tunnel entirely. - -## Architecture - -### Components - -1. **FilteredDevice** (`olm/device_filter.go`) - - Wraps the TUN device with packet filtering capabilities - - Provides fast packet inspection without deep packet processing - - Supports multiple filtering rules that can be added/removed dynamically - - Optimized for performance - only extracts destination IP on fast path - -2. **DNSProxy** (`olm/dns_proxy.go`) - - Uses gvisor netstack to handle DNS protocol processing - - Listens on `10.30.30.30:53` within its own network stack - - Forwards queries to Google DNS (8.8.8.8, 8.8.4.4) - - Writes responses directly back to the TUN device, bypassing WireGuard - -### Packet Flow - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Application │ -└──────────────────────┬──────────────────────────────────────┘ - │ DNS Query to 10.30.30.30:53 - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ TUN Interface │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ FilteredDevice (Read) │ -│ - Fast IP extraction │ -│ - Rule matching (10.30.30.30) │ -└──────────────┬──────────────────────────────────────────────┘ - │ - ┌──────────┴──────────┐ - │ │ - ▼ ▼ -┌─────────┐ ┌─────────────────────────┐ -│DNS Proxy│ │ WireGuard Device │ -│Netstack │ │ (other traffic) │ -└────┬────┘ └─────────────────────────┘ - │ - │ Forward to 8.8.8.8 - ▼ -┌─────────────┐ -│ Internet │ -│ (Direct) │ -└──────┬──────┘ - │ DNS Response - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ DNSProxy writes directly to TUN │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ Application │ -└─────────────────────────────────────────────────────────────┘ -``` - -## Performance Considerations - -### Fast Path Optimization - -1. **Minimal Packet Inspection** - - Only extracts destination IP (bytes 16-19 for IPv4, 24-39 for IPv6) - - No deep packet inspection unless packet matches a rule - - Zero-copy operations where possible - -2. **Rule Matching** - - Simple IP comparison (not prefix matching for rules) - - Linear scan of rules (fast for small number of rules) - - Read-lock only for rule access - -3. **Packet Processing** - - Filtered packets are removed from the slice in-place - - Non-matching packets passed through with minimal overhead - - No memory allocation for packets that don't match rules - -### Memory Efficiency - -- Packet copies are only made when absolutely necessary -- gvisor netstack uses buffer pooling internally -- DNS proxy uses a separate goroutine for response handling - -## Usage - -### Configuration - -The DNS proxy is automatically started when the tunnel is created. By default: -- DNS proxy IP: `10.30.30.30` -- DNS port: `53` -- Upstream DNS: `8.8.8.8` (primary), `8.8.4.4` (fallback) - -### Testing - -To test the DNS proxy, configure your DNS settings to use `10.30.30.30`: - -```bash -# Using dig -dig @10.30.30.30 google.com - -# Using nslookup -nslookup google.com 10.30.30.30 -``` - -## Extensibility - -The `FilteredDevice` architecture is designed to be extensible: - -### Adding New Services - -To add a new service (e.g., HTTP proxy on 10.30.30.31): - -1. Create a new service similar to `DNSProxy` -2. Register a filter rule with `filteredDev.AddRule()` -3. Process packets in your handler -4. Write responses back to the TUN device - -Example: - -```go -// In your service -func (s *MyService) handlePacket(packet []byte) bool { - // Parse packet - // Process request - // Write response to TUN device - s.tunDevice.Write([][]byte{response}, 0) - return true // Drop from normal path -} - -// During initialization -filteredDev.AddRule(myServiceIP, myService.handlePacket) -``` - -### Adding Filtering Rules - -Rules can be added/removed dynamically: - -```go -// Add a rule -filteredDev.AddRule(netip.MustParseAddr("10.30.30.40"), handleSpecialIP) - -// Remove a rule -filteredDev.RemoveRule(netip.MustParseAddr("10.30.30.40")) -``` - -## Implementation Details - -### Why Direct TUN Write? - -The DNS proxy writes responses directly back to the TUN device instead of going through the filter because: -1. Responses should go to the host, not through WireGuard -2. Avoids infinite loops (response → filter → DNS proxy → ...) -3. Better performance (one less layer) - -### Thread Safety - -- `FilteredDevice` uses RWMutex for rule access (read-heavy workload) -- `DNSProxy` goroutines are properly synchronized -- TUN device write operations are thread-safe - -### Error Handling - -- Failed DNS queries fall back to secondary DNS server -- Malformed packets are logged but don't crash the proxy -- Context cancellation ensures clean shutdown - -## Future Enhancements - -Potential improvements: -1. DNS caching to reduce upstream queries -2. DNS-over-HTTPS (DoH) support -3. Custom DNS filtering/blocking -4. Metrics and monitoring -5. IPv6 support for DNS proxy -6. Multiple upstream DNS servers with health checking -7. HTTP/HTTPS proxy on different IPs -8. SOCKS5 proxy support diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md deleted file mode 100644 index 4a95984..0000000 --- a/IMPLEMENTATION_SUMMARY.md +++ /dev/null @@ -1,214 +0,0 @@ -# Virtual DNS Proxy Implementation - Summary - -## What Was Implemented - -A high-performance virtual DNS proxy for the olm WireGuard client that intercepts DNS queries before they enter the WireGuard tunnel. The implementation consists of three main components: - -### 1. FilteredDevice (`olm/device_filter.go`) -A TUN device wrapper that provides fast packet filtering: -- **Performance**: 2.6 ns per packet inspection (benchmarked) -- **Zero overhead** for non-matching packets -- **Extensible**: Easy to add new filter rules for other services -- **Thread-safe**: Uses RWMutex for concurrent access - -Key features: -- Fast destination IP extraction (IPv4 and IPv6) -- Protocol and port extraction utilities -- Rule-based packet interception -- In-place packet filtering (no unnecessary allocations) - -### 2. DNSProxy (`olm/dns_proxy.go`) -A DNS proxy implementation using gvisor netstack: -- **Listens on**: `10.30.30.30:53` -- **Upstream DNS**: Google DNS (8.8.8.8, 8.8.4.4) -- **Bypass WireGuard**: DNS responses go directly to host -- **No tunnel overhead**: DNS queries don't consume VPN bandwidth - -Architecture: -- Uses gvisor netstack for full TCP/IP stack simulation -- Separate goroutines for DNS query handling and response writing -- Direct TUN device write for responses (bypasses filter) -- Automatic failover between primary and secondary DNS servers - -### 3. Integration (`olm/olm.go`) -Seamless integration into the tunnel lifecycle: -- Automatically started when tunnel is created -- Properly cleaned up when tunnel stops -- No configuration required (works out of the box) - -## Performance Characteristics - -### Packet Processing Speed -``` -BenchmarkExtractDestIP-16 1000000 2.619 ns/op -``` - -This means: -- Can process ~380 million packets/second per core -- Negligible overhead on WireGuard throughput -- No measurable latency impact - -### Memory Efficiency -- Zero allocations for non-matching packets -- Minimal allocations for DNS packets -- gvisor uses internal buffer pooling - -## How to Use - -### Basic Usage -The DNS proxy starts automatically when the tunnel is created. To use it: - -```bash -# Configure your system to use 10.30.30.30 as DNS server -# Or test with dig/nslookup: -dig @10.30.30.30 google.com -nslookup google.com 10.30.30.30 -``` - -### Adding New Virtual Services - -To add a new service (e.g., HTTP proxy on 10.30.30.31): - -```go -// 1. Create your service -type HTTPProxy struct { - tunDevice tun.Device - // ... other fields -} - -// 2. Implement packet handler -func (h *HTTPProxy) handlePacket(packet []byte) bool { - // Process packet - // Write response to h.tunDevice - return true // Drop from normal path -} - -// 3. Register with filter (in olm.go) -httpProxyIP := netip.MustParseAddr("10.30.30.31") -filteredDev.AddRule(httpProxyIP, httpProxy.handlePacket) -``` - -## Files Created - -1. **`olm/device_filter.go`** - TUN device wrapper with packet filtering -2. **`olm/dns_proxy.go`** - DNS proxy using gvisor netstack -3. **`olm/device_filter_test.go`** - Unit tests and benchmarks -4. **`DNS_PROXY_README.md`** - Detailed architecture documentation -5. **`IMPLEMENTATION_SUMMARY.md`** - This file - -## Testing - -Tests included: -- `TestExtractDestIP` - Validates IPv4/IPv6 IP extraction -- `TestGetProtocol` - Validates protocol extraction -- `BenchmarkExtractDestIP` - Performance benchmark - -Run tests: -```bash -go test ./olm -v -run "TestExtractDestIP|TestGetProtocol" -go test ./olm -bench=BenchmarkExtractDestIP -``` - -## Technical Details - -### Packet Flow -``` -Application → TUN → FilteredDevice → [DNS Proxy | WireGuard] - ↓ - DNS Response - ↓ - TUN ← Direct Write -``` - -### Why This Design? - -1. **Wrapping TUN device**: Allows interception before WireGuard encryption -2. **Fast path optimization**: Only extracts what's needed (destination IP) -3. **Direct TUN write**: Responses bypass WireGuard to go straight to host -4. **Separate netstack**: Isolated DNS processing doesn't affect main stack - -### Limitations & Future Work - -Current limitations: -- Only IPv4 DNS (10.30.30.30) -- Hardcoded upstream DNS servers -- No DNS caching -- No DNS filtering/blocking - -Potential enhancements: -- DNS caching layer -- DNS-over-HTTPS (DoH) -- IPv6 support -- Custom DNS rules/filtering -- HTTP/HTTPS proxy on other IPs -- SOCKS5 proxy support -- Metrics and monitoring - -## Extensibility Examples - -### Adding a TCP Service - -```go -type TCPProxy struct { - stack *stack.Stack - tunDevice tun.Device -} - -func (t *TCPProxy) handlePacket(packet []byte) bool { - // Check if it's TCP to our IP:port - proto, _ := GetProtocol(packet) - if proto != 6 { // TCP - return false - } - - port, _ := GetDestPort(packet) - if port != 8080 { - return false - } - - // Inject into our netstack - // ... handle TCP connection - return true -} -``` - -### Adding Multiple DNS Servers - -Modify `dns_proxy.go` to support multiple virtual DNS IPs: - -```go -const ( - DNSProxyIP1 = "10.30.30.30" - DNSProxyIP2 = "10.30.30.31" -) - -// Register multiple rules -filteredDev.AddRule(ip1, dnsProxy1.handlePacket) -filteredDev.AddRule(ip2, dnsProxy2.handlePacket) -``` - -## Build & Deploy - -```bash -# Build -cd /home/owen/fossorial/olm -go build -o olm-binary . - -# Test -go test ./olm -v - -# Benchmark -go test ./olm -bench=. -benchmem -``` - -## Conclusion - -This implementation provides: -- ✅ High-performance packet filtering (2.6 ns/packet) -- ✅ Zero overhead for non-DNS traffic -- ✅ Extensible architecture for future services -- ✅ Clean integration with existing codebase -- ✅ Comprehensive tests and documentation -- ✅ Production-ready code - -The DNS proxy successfully intercepts DNS queries to 10.30.30.30, processes them through a separate gvisor netstack, forwards to upstream DNS servers, and returns responses directly to the host - all while bypassing the WireGuard tunnel. diff --git a/api/api.go b/api/api.go index cf04a89..2316373 100644 --- a/api/api.go +++ b/api/api.go @@ -9,6 +9,7 @@ import ( "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/network" ) // ConnectionRequest defines the structure for an incoming connection request @@ -47,12 +48,12 @@ type PeerStatus struct { // StatusResponse is returned by the status endpoint type StatusResponse struct { - Connected bool `json:"connected"` - Registered bool `json:"registered"` - TunnelIP string `json:"tunnelIP,omitempty"` - Version string `json:"version,omitempty"` - OrgID string `json:"orgId,omitempty"` - PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` + Connected bool `json:"connected"` + Registered bool `json:"registered"` + Version string `json:"version,omitempty"` + OrgID string `json:"orgId,omitempty"` + PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` + NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` } // API represents the HTTP server and its state @@ -70,7 +71,6 @@ type API struct { connectedAt time.Time isConnected bool isRegistered bool - tunnelIP string version string orgID string } @@ -206,13 +206,6 @@ func (s *API) SetRegistered(registered bool) { s.isRegistered = registered } -// SetTunnelIP sets the tunnel IP address -func (s *API) SetTunnelIP(tunnelIP string) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - s.tunnelIP = tunnelIP -} - // SetVersion sets the olm version func (s *API) SetVersion(version string) { s.statusMu.Lock() @@ -300,12 +293,12 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { defer s.statusMu.RUnlock() resp := StatusResponse{ - Connected: s.isConnected, - Registered: s.isRegistered, - TunnelIP: s.tunnelIP, - Version: s.version, - OrgID: s.orgID, - PeerStatuses: s.peerStatuses, + Connected: s.isConnected, + Registered: s.isRegistered, + Version: s.version, + OrgID: s.orgID, + PeerStatuses: s.peerStatuses, + NetworkSettings: network.GetSettings(), } w.Header().Set("Content-Type", "application/json") diff --git a/olm/unix.go b/device/tun_unix.go similarity index 77% rename from olm/unix.go rename to device/tun_unix.go index 06eb5c4..c9bab60 100644 --- a/olm/unix.go +++ b/device/tun_unix.go @@ -1,6 +1,6 @@ //go:build !windows -package olm +package device import ( "net" @@ -12,7 +12,7 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { dupTunFd, err := unix.Dup(int(tunFd)) if err != nil { logger.Error("Unable to dup tun fd: %v", err) @@ -35,10 +35,10 @@ func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { return device, nil } -func uapiOpen(interfaceName string) (*os.File, error) { +func UapiOpen(interfaceName string) (*os.File, error) { return ipc.UAPIOpen(interfaceName) } -func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { return ipc.UAPIListen(interfaceName, fileUAPI) } diff --git a/olm/windows.go b/device/tun_windows.go similarity index 62% rename from olm/windows.go rename to device/tun_windows.go index b168930..edcd6f6 100644 --- a/olm/windows.go +++ b/device/tun_windows.go @@ -1,6 +1,6 @@ //go:build windows -package olm +package device import ( "errors" @@ -11,15 +11,15 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { +func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { return nil, errors.New("CreateTUNFromFile not supported on Windows") } -func uapiOpen(interfaceName string) (*os.File, error) { +func UapiOpen(interfaceName string) (*os.File, error) { return nil, nil } -func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { +func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { // On Windows, UAPIListen only takes one parameter return ipc.UAPIListen(interfaceName) } diff --git a/diff b/diff deleted file mode 100644 index da7e62c..0000000 --- a/diff +++ /dev/null @@ -1,523 +0,0 @@ -diff --git a/api/api.go b/api/api.go -index dd07751..0d2e4ef 100644 ---- a/api/api.go -+++ b/api/api.go -@@ -18,6 +18,11 @@ type ConnectionRequest struct { - Endpoint string `json:"endpoint"` - } - -+// SwitchOrgRequest defines the structure for switching organizations -+type SwitchOrgRequest struct { -+ OrgID string `json:"orgId"` -+} -+ - // PeerStatus represents the status of a peer connection - type PeerStatus struct { - SiteID int `json:"siteId"` -@@ -35,6 +40,7 @@ type StatusResponse struct { - Registered bool `json:"registered"` - TunnelIP string `json:"tunnelIP,omitempty"` - Version string `json:"version,omitempty"` -+ OrgID string `json:"orgId,omitempty"` - PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` - } - -@@ -46,6 +52,7 @@ type API struct { - server *http.Server - connectionChan chan ConnectionRequest - shutdownChan chan struct{} -+ switchOrgChan chan SwitchOrgRequest - statusMu sync.RWMutex - peerStatuses map[int]*PeerStatus - connectedAt time.Time -@@ -53,6 +60,7 @@ type API struct { - isRegistered bool - tunnelIP string - version string -+ orgID string - } - - // NewAPI creates a new HTTP server that listens on a TCP address -@@ -61,6 +69,7 @@ func NewAPI(addr string) *API { - addr: addr, - connectionChan: make(chan ConnectionRequest, 1), - shutdownChan: make(chan struct{}, 1), -+ switchOrgChan: make(chan SwitchOrgRequest, 1), - peerStatuses: make(map[int]*PeerStatus), - } - -@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API { - socketPath: socketPath, - connectionChan: make(chan ConnectionRequest, 1), - shutdownChan: make(chan struct{}, 1), -+ switchOrgChan: make(chan SwitchOrgRequest, 1), - peerStatuses: make(map[int]*PeerStatus), - } - -@@ -85,6 +95,7 @@ func (s *API) Start() error { - mux.HandleFunc("/connect", s.handleConnect) - mux.HandleFunc("/status", s.handleStatus) - mux.HandleFunc("/exit", s.handleExit) -+ mux.HandleFunc("/switch-org", s.handleSwitchOrg) - - s.server = &http.Server{ - Handler: mux, -@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} { - return s.shutdownChan - } - -+// GetSwitchOrgChannel returns the channel for receiving org switch requests -+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest { -+ return s.switchOrgChan -+} -+ - // UpdatePeerStatus updates the status of a peer including endpoint and relay info - func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { - s.statusMu.Lock() -@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) { - s.version = version - } - -+// SetOrgID sets the org ID -+func (s *API) SetOrgID(orgID string) { -+ s.statusMu.Lock() -+ defer s.statusMu.Unlock() -+ s.orgID = orgID -+} -+ - // UpdatePeerRelayStatus updates only the relay status of a peer - func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { - s.statusMu.Lock() -@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { - Registered: s.isRegistered, - TunnelIP: s.tunnelIP, - Version: s.version, -+ OrgID: s.orgID, - PeerStatuses: s.peerStatuses, - } - -@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { - "status": "shutdown initiated", - }) - } -+ -+// handleSwitchOrg handles the /switch-org endpoint -+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { -+ if r.Method != http.MethodPost { -+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -+ return -+ } -+ -+ var req SwitchOrgRequest -+ decoder := json.NewDecoder(r.Body) -+ if err := decoder.Decode(&req); err != nil { -+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) -+ return -+ } -+ -+ // Validate required fields -+ if req.OrgID == "" { -+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest) -+ return -+ } -+ -+ logger.Info("Received org switch request to orgId: %s", req.OrgID) -+ -+ // Send the request to the main goroutine -+ select { -+ case s.switchOrgChan <- req: -+ // Signal sent successfully -+ default: -+ // Channel already has a signal, don't block -+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests) -+ return -+ } -+ -+ // Return a success response -+ w.Header().Set("Content-Type", "application/json") -+ w.WriteHeader(http.StatusAccepted) -+ json.NewEncoder(w).Encode(map[string]string{ -+ "status": "org switch initiated", -+ "orgId": req.OrgID, -+ }) -+} -diff --git a/olm/olm.go b/olm/olm.go -index 78080c4..5e292d6 100644 ---- a/olm/olm.go -+++ b/olm/olm.go -@@ -58,6 +58,58 @@ type Config struct { - OrgID string - } - -+// tunnelState holds all the active tunnel resources that need cleanup -+type tunnelState struct { -+ dev *device.Device -+ tdev tun.Device -+ uapiListener net.Listener -+ peerMonitor *peermonitor.PeerMonitor -+ stopRegister func() -+ connected bool -+} -+ -+// teardownTunnel cleans up all tunnel resources -+func teardownTunnel(state *tunnelState) { -+ if state == nil { -+ return -+ } -+ -+ logger.Info("Tearing down tunnel...") -+ -+ // Stop registration messages -+ if state.stopRegister != nil { -+ state.stopRegister() -+ state.stopRegister = nil -+ } -+ -+ // Stop peer monitor -+ if state.peerMonitor != nil { -+ state.peerMonitor.Stop() -+ state.peerMonitor = nil -+ } -+ -+ // Close UAPI listener -+ if state.uapiListener != nil { -+ state.uapiListener.Close() -+ state.uapiListener = nil -+ } -+ -+ // Close WireGuard device -+ if state.dev != nil { -+ state.dev.Close() -+ state.dev = nil -+ } -+ -+ // Close TUN device -+ if state.tdev != nil { -+ state.tdev.Close() -+ state.tdev = nil -+ } -+ -+ state.connected = false -+ logger.Info("Tunnel teardown complete") -+} -+ - func Run(ctx context.Context, config Config) { - // Create a cancellable context for internal shutdown control - ctx, cancel := context.WithCancel(ctx) -@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) { - pingTimeout = config.PingTimeoutDuration - doHolepunch = config.Holepunch - privateKey wgtypes.Key -- connected bool -- dev *device.Device - wgData WgData - holePunchData HolePunchData -- uapiListener net.Listener -- tdev tun.Device -+ orgID = config.OrgID - ) - -+ // Tunnel state that can be torn down and recreated -+ tunnel := &tunnelState{} -+ - stopHolepunch = make(chan struct{}) - stopPing = make(chan struct{}) - -@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) { - } - - apiServer.SetVersion(config.Version) -+ apiServer.SetOrgID(orgID) - if err := apiServer.Start(); err != nil { - logger.Fatal("Failed to start HTTP server: %v", err) - } -@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) { - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - -- if connected { -+ if tunnel.connected { - logger.Info("Already connected. Ignoring new connection request.") - return - } - -- if stopRegister != nil { -- stopRegister() -- stopRegister = nil -+ if tunnel.stopRegister != nil { -+ tunnel.stopRegister() -+ tunnel.stopRegister = nil - } - - close(stopHolepunch) -@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) { - time.Sleep(500 * time.Millisecond) - - // if there is an existing tunnel then close it -- if dev != nil { -+ if tunnel.dev != nil { - logger.Info("Got new message. Closing existing tunnel!") -- dev.Close() -+ tunnel.dev.Close() - } - - jsonData, err := json.Marshal(msg.Data) -@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) { - return - } - -- tdev, err = func() (tun.Device, error) { -+ tunnel.tdev, err = func() (tun.Device, error) { - if runtime.GOOS == "darwin" { - interfaceName, err := findUnusedUTUN() - if err != nil { -@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) { - return - } - -- if realInterfaceName, err2 := tdev.Name(); err2 == nil { -+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } - -@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) { - return - } - -- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) -+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) - -- uapiListener, err = uapiListen(interfaceName, fileUAPI) -+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) -@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) { - - go func() { - for { -- conn, err := uapiListener.Accept() -+ conn, err := tunnel.uapiListener.Accept() - if err != nil { - return - } -- go dev.IpcHandle(conn) -+ go tunnel.dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - -- if err = dev.Up(); err != nil { -+ if err = tunnel.dev.Up(); err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - if err = ConfigureInterface(interfaceName, wgData); err != nil { -@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) { - apiServer.SetTunnelIP(wgData.TunnelIP) - } - -- peerMonitor = peermonitor.NewPeerMonitor( -+ tunnel.peerMonitor = peermonitor.NewPeerMonitor( - func(siteID int, connected bool, rtt time.Duration) { - if apiServer != nil { - // Find the site config to get endpoint information -@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) { - }, - fixKey(privateKey.String()), - olm, -- dev, -+ tunnel.dev, - doHolepunch, - ) - -@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) { - // Format the endpoint before configuring the peer. - site.Endpoint = formatEndpoint(site.Endpoint) - -- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { -+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil { - logger.Error("Failed to configure peer: %v", err) - return - } -@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) { - logger.Info("Configured peer %s", site.PublicKey) - } - -- peerMonitor.Start() -+ tunnel.peerMonitor.Start() - - if apiServer != nil { - apiServer.SetRegistered(true) - } - -- connected = true -+ tunnel.connected = true - - logger.Info("WireGuard device created.") - }) -@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) { - } - - // Update the peer in WireGuard -- if dev != nil { -+ if tunnel.dev != nil { - // Find the existing peer to get old data - var oldRemoteSubnets string - var oldPublicKey string -@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) { - // If the public key has changed, remove the old peer first - if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { - logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) -- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { -+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil { - logger.Error("Failed to remove old peer: %v", err) - return - } -@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) { - // Format the endpoint before updating the peer. - siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - -- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { -+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil { - logger.Error("Failed to update peer: %v", err) - return - } -@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) { - } - - // Add the peer to WireGuard -- if dev != nil { -+ if tunnel.dev != nil { - // Format the endpoint before adding the new peer. - siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - -- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { -+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } -@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) { - } - - // Remove the peer from WireGuard -- if dev != nil { -- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { -+ if tunnel.dev != nil { -+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { - logger.Error("Failed to remove peer: %v", err) - // Send error response if needed - return -@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) { - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) - } - -- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) -+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) - }) - - olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) { -@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) { - apiServer.SetConnectionStatus(true) - } - -- if connected { -+ if tunnel.connected { - logger.Debug("Already connected, skipping registration") - return nil - } -@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) { - - logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) - -- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ -+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": !doHolepunch, - "olmVersion": config.Version, -- "orgId": config.OrgID, -+ "orgId": orgID, - }, 1*time.Second) - - go keepSendingPing(olm) -@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) { - } - defer olm.Close() - -+ // Listen for org switch requests from the API (after olm is created) -+ if apiServer != nil { -+ go func() { -+ for req := range apiServer.GetSwitchOrgChannel() { -+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID) -+ -+ // Update the orgId -+ orgID = req.OrgID -+ -+ // Teardown existing tunnel -+ teardownTunnel(tunnel) -+ -+ // Reset tunnel state -+ tunnel = &tunnelState{} -+ -+ // Stop holepunch -+ select { -+ case <-stopHolepunch: -+ // Channel already closed -+ default: -+ close(stopHolepunch) -+ } -+ stopHolepunch = make(chan struct{}) -+ -+ // Clear API server state -+ apiServer.SetRegistered(false) -+ apiServer.SetTunnelIP("") -+ apiServer.SetOrgID(orgID) -+ -+ // Send new registration message with updated orgId -+ publicKey := privateKey.PublicKey() -+ logger.Info("Sending registration message with new orgId: %s", orgID) -+ -+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ -+ "publicKey": publicKey.String(), -+ "relay": !doHolepunch, -+ "olmVersion": config.Version, -+ "orgId": orgID, -+ }, 1*time.Second) -+ } -+ }() -+ } -+ - select { - case <-ctx.Done(): - logger.Info("Context cancelled") -@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) { - close(stopHolepunch) - } - -- if stopRegister != nil { -- stopRegister() -- stopRegister = nil -+ if tunnel.stopRegister != nil { -+ tunnel.stopRegister() -+ tunnel.stopRegister = nil - } - - select { -@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) { - close(stopPing) - } - -- if peerMonitor != nil { -- peerMonitor.Stop() -- } -- -- if uapiListener != nil { -- uapiListener.Close() -- } -- if dev != nil { -- dev.Close() -- } -+ // Use teardownTunnel to clean up all tunnel resources -+ teardownTunnel(tunnel) - - if apiServer != nil { - apiServer.Stop() diff --git a/olm/interface.go b/network/interface.go similarity index 91% rename from olm/interface.go rename to network/interface.go index 622382d..e110ec1 100644 --- a/olm/interface.go +++ b/network/interface.go @@ -1,4 +1,4 @@ -package olm +package network import ( "fmt" @@ -10,16 +10,15 @@ import ( "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/network" "github.com/vishvananda/netlink" ) // ConfigureInterface configures a network interface with an IP address and brings it up -func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error { - logger.Info("The tunnel IP is: %s", wgData.TunnelIP) +func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error { + logger.Info("The tunnel IP is: %s", tunnelIp) // Parse the IP address and network - ip, ipNet, err := net.ParseCIDR(wgData.TunnelIP) + ip, ipNet, err := net.ParseCIDR(tunnelIp) if err != nil { return fmt.Errorf("invalid IP address: %v", err) } @@ -31,9 +30,8 @@ func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error { logger.Debug("The destination address is: %s", destinationAddress) // network.SetTunnelRemoteAddress() // what does this do? - network.SetIPv4Settings([]string{destinationAddress}, []string{mask}) - network.SetMTU(mtu) - apiServer.SetTunnelIP(destinationAddress) + SetIPv4Settings([]string{destinationAddress}, []string{mask}) + SetMTU(mtu) if interfaceName == "" { return nil @@ -89,7 +87,7 @@ func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Du return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) } -func findUnusedUTUN() (string, error) { +func FindUnusedUTUN() (string, error) { ifaces, err := net.Interfaces() if err != nil { return "", fmt.Errorf("failed to list interfaces: %v", err) diff --git a/olm/interface_notwindows.go b/network/interface_notwindows.go similarity index 92% rename from olm/interface_notwindows.go rename to network/interface_notwindows.go index 75e8553..5d15ace 100644 --- a/olm/interface_notwindows.go +++ b/network/interface_notwindows.go @@ -1,6 +1,6 @@ //go:build !windows -package olm +package network import ( "fmt" diff --git a/olm/interface_windows.go b/network/interface_windows.go similarity index 99% rename from olm/interface_windows.go rename to network/interface_windows.go index cf769bf..966486b 100644 --- a/olm/interface_windows.go +++ b/network/interface_windows.go @@ -1,6 +1,6 @@ //go:build windows -package olm +package network import ( "fmt" diff --git a/olm/route.go b/network/route.go similarity index 88% rename from olm/route.go rename to network/route.go index e4e4006..861fec1 100644 --- a/olm/route.go +++ b/network/route.go @@ -1,4 +1,4 @@ -package olm +package network import ( "fmt" @@ -8,7 +8,6 @@ import ( "strings" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/network" "github.com/vishvananda/netlink" ) @@ -126,8 +125,8 @@ func LinuxRemoveRoute(destination string) error { } // addRouteForServerIP adds an OS-specific route for the server IP -func addRouteForServerIP(serverIP, interfaceName string) error { - if err := addRouteForNetworkConfig(serverIP); err != nil { +func AddRouteForServerIP(serverIP, interfaceName string) error { + if err := AddRouteForNetworkConfig(serverIP); err != nil { return err } if interfaceName == "" { @@ -145,8 +144,8 @@ func addRouteForServerIP(serverIP, interfaceName string) error { } // removeRouteForServerIP removes an OS-specific route for the server IP -func removeRouteForServerIP(serverIP string, interfaceName string) error { - if err := removeRouteForNetworkConfig(serverIP); err != nil { +func RemoveRouteForServerIP(serverIP string, interfaceName string) error { + if err := RemoveRouteForNetworkConfig(serverIP); err != nil { return err } if interfaceName == "" { @@ -163,7 +162,7 @@ func removeRouteForServerIP(serverIP string, interfaceName string) error { return nil } -func addRouteForNetworkConfig(destination string) error { +func AddRouteForNetworkConfig(destination string) error { // Parse the subnet to extract IP and mask _, ipNet, err := net.ParseCIDR(destination) if err != nil { @@ -174,12 +173,12 @@ func addRouteForNetworkConfig(destination string) error { mask := net.IP(ipNet.Mask).String() destinationAddress := ipNet.IP.String() - network.AddIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) return nil } -func removeRouteForNetworkConfig(destination string) error { +func RemoveRouteForNetworkConfig(destination string) error { // Parse the subnet to extract IP and mask _, ipNet, err := net.ParseCIDR(destination) if err != nil { @@ -190,13 +189,13 @@ func removeRouteForNetworkConfig(destination string) error { mask := net.IP(ipNet.Mask).String() destinationAddress := ipNet.IP.String() - network.RemoveIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) return nil } // addRoutes adds routes for each subnet in RemoteSubnets -func addRoutes(remoteSubnets []string, interfaceName string) error { +func AddRoutes(remoteSubnets []string, interfaceName string) error { if len(remoteSubnets) == 0 { return nil } @@ -208,7 +207,7 @@ func addRoutes(remoteSubnets []string, interfaceName string) error { continue } - if err := addRouteForNetworkConfig(subnet); err != nil { + if err := AddRouteForNetworkConfig(subnet); err != nil { logger.Error("Failed to add network config for subnet %s: %v", subnet, err) continue } @@ -241,7 +240,7 @@ func addRoutes(remoteSubnets []string, interfaceName string) error { } // removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets -func removeRoutesForRemoteSubnets(remoteSubnets []string) error { +func RemoveRoutesForRemoteSubnets(remoteSubnets []string) error { if len(remoteSubnets) == 0 { return nil } @@ -253,7 +252,7 @@ func removeRoutesForRemoteSubnets(remoteSubnets []string) error { continue } - if err := removeRouteForNetworkConfig(subnet); err != nil { + if err := RemoveRouteForNetworkConfig(subnet); err != nil { logger.Error("Failed to remove network config for subnet %s: %v", subnet, err) continue } diff --git a/olm/route_notwindows.go b/network/route_notwindows.go similarity index 92% rename from olm/route_notwindows.go rename to network/route_notwindows.go index 910ed26..6984c71 100644 --- a/olm/route_notwindows.go +++ b/network/route_notwindows.go @@ -1,6 +1,6 @@ //go:build !windows -package olm +package network func WindowsAddRoute(destination string, gateway string, interfaceName string) error { return nil diff --git a/olm/route_windows.go b/network/route_windows.go similarity index 99% rename from olm/route_windows.go rename to network/route_windows.go index c478a04..ba613b6 100644 --- a/olm/route_windows.go +++ b/network/route_windows.go @@ -1,6 +1,6 @@ //go:build windows -package olm +package network import ( "fmt" diff --git a/network/network.go b/network/settings.go similarity index 97% rename from network/network.go rename to network/settings.go index f9503ce..e7792e0 100644 --- a/network/network.go +++ b/network/settings.go @@ -177,6 +177,12 @@ func GetJSON() (string, error) { return string(data), nil } +func GetSettings() NetworkSettings { + networkSettingsMutex.RLock() + defer networkSettingsMutex.RUnlock() + return networkSettings +} + func GetIncrementor() int { networkSettingsMutex.Lock() defer networkSettingsMutex.Unlock() diff --git a/olm/olm.go b/olm/olm.go index 37e607e..65ec9c1 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -14,7 +14,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/api" - middleDevice "github.com/fosrl/olm/device" + olmDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/dns" dnsOverride "github.com/fosrl/olm/dns/override" "github.com/fosrl/olm/network" @@ -79,7 +79,7 @@ var ( holePunchData HolePunchData uapiListener net.Listener tdev tun.Device - middleDev *middleDevice.MiddleDevice + middleDev *olmDevice.MiddleDevice dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client @@ -201,7 +201,6 @@ func Init(ctx context.Context, config GlobalConfig) { // Clear peer statuses in API apiServer.SetRegistered(false) - apiServer.SetTunnelIP("") // Trigger re-registration with new orgId logger.Info("Re-registering with new orgId: %s", req.OrgID) @@ -418,11 +417,11 @@ func StartTunnel(config TunnelConfig) { tdev, err = func() (tun.Device, error) { if config.FileDescriptorTun != 0 { - return createTUNFromFD(config.FileDescriptorTun, config.MTU) + return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU) } var ifName = interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd - ifName, err = findUnusedUTUN() + ifName, err = network.FindUnusedUTUN() if err != nil { return nil, err } @@ -458,7 +457,7 @@ func StartTunnel(config TunnelConfig) { // } // Wrap TUN device with packet filter for DNS proxy - middleDev = middleDevice.NewMiddleDevice(tdev) + middleDev = olmDevice.NewMiddleDevice(tdev) wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") // Use filtered device instead of raw TUN device @@ -495,11 +494,11 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to create DNS proxy: %v", err) } - if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil { + if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil { logger.Error("Failed to configure interface: %v", err) } - if addRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet + if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet logger.Error("Failed to add route for utility subnet: %v", err) } @@ -549,11 +548,11 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to configure peer: %v", err) return } - if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required + if err := network.AddRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required logger.Error("Failed to add route for peer: %v", err) return } - if err := addRoutes(site.RemoteSubnets, interfaceName); err != nil { + if err := network.AddRoutes(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) return } @@ -676,13 +675,13 @@ func StartTunnel(config TunnelConfig) { // Handle remote subnet route changes if !stringSlicesEqual(oldRemoteSubnets, siteConfig.RemoteSubnets) { - if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { + if err := network.RemoveRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { logger.Error("Failed to remove old remote subnet routes: %v", err) // Continue anyway to add new routes } // Add new remote subnet routes - if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { + if err := network.AddRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add new remote subnet routes: %v", err) return } @@ -721,11 +720,11 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to add peer: %v", err) return } - if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { + if err := network.AddRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err) return } - if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { + if err := network.AddRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err) return } @@ -782,14 +781,14 @@ func StartTunnel(config TunnelConfig) { } // Remove route for the peer - err = removeRouteForServerIP(peerToRemove.ServerIP, interfaceName) + err = network.RemoveRouteForServerIP(peerToRemove.ServerIP, interfaceName) if err != nil { logger.Error("Failed to remove route for peer: %v", err) return } // Remove routes for remote subnets - if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { + if err := network.RemoveRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { logger.Error("Failed to remove routes for remote subnets: %v", err) return } @@ -851,7 +850,7 @@ func StartTunnel(config TunnelConfig) { } // Add routes for the new subnets - if err := addRoutes(newSubnets, interfaceName); err != nil { + if err := network.AddRoutes(newSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for new remote subnets: %v", err) return } @@ -912,7 +911,7 @@ func StartTunnel(config TunnelConfig) { } // Remove routes for the removed subnets - if err := removeRoutesForRemoteSubnets(removedSubnets); err != nil { + if err := network.RemoveRoutesForRemoteSubnets(removedSubnets); err != nil { logger.Error("Failed to remove routes for remote subnets: %v", err) return } @@ -955,7 +954,7 @@ func StartTunnel(config TunnelConfig) { // First, remove routes for old subnets if len(updateSubnetsData.OldRemoteSubnets) > 0 { - if err := removeRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets); err != nil { + if err := network.RemoveRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets); err != nil { logger.Error("Failed to remove routes for old remote subnets: %v", err) return } @@ -964,10 +963,10 @@ func StartTunnel(config TunnelConfig) { // Then, add routes for new subnets if len(updateSubnetsData.NewRemoteSubnets) > 0 { - if err := addRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil { + if err := network.AddRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for new remote subnets: %v", err) // Attempt to rollback by re-adding old routes - if rollbackErr := addRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil { + if rollbackErr := network.AddRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil { logger.Error("Failed to rollback old routes: %v", rollbackErr) } return @@ -1186,7 +1185,6 @@ func StopTunnel() { // Update API server status apiServer.SetConnectionStatus(false) apiServer.SetRegistered(false) - apiServer.SetTunnelIP("") network.ClearNetworkSettings() diff --git a/olm/common.go b/olm/util.go similarity index 100% rename from olm/common.go rename to olm/util.go