Refactor modules

Former-commit-id: 20b3331fff
This commit is contained in:
Owen
2025-11-24 17:04:33 -05:00
parent 0802673048
commit fff234bdd5
15 changed files with 71 additions and 1000 deletions

View File

@@ -1,186 +0,0 @@
# Virtual DNS Proxy Implementation
## Overview
This implementation adds a high-performance virtual DNS proxy that intercepts DNS queries destined for `10.30.30.30:53` before they reach the WireGuard tunnel. The proxy processes DNS queries using a gvisor netstack and forwards them to upstream DNS servers, bypassing the VPN tunnel entirely.
## Architecture
### Components
1. **FilteredDevice** (`olm/device_filter.go`)
- Wraps the TUN device with packet filtering capabilities
- Provides fast packet inspection without deep packet processing
- Supports multiple filtering rules that can be added/removed dynamically
- Optimized for performance - only extracts destination IP on fast path
2. **DNSProxy** (`olm/dns_proxy.go`)
- Uses gvisor netstack to handle DNS protocol processing
- Listens on `10.30.30.30:53` within its own network stack
- Forwards queries to Google DNS (8.8.8.8, 8.8.4.4)
- Writes responses directly back to the TUN device, bypassing WireGuard
### Packet Flow
```
┌─────────────────────────────────────────────────────────────┐
│ Application │
└──────────────────────┬──────────────────────────────────────┘
│ DNS Query to 10.30.30.30:53
┌─────────────────────────────────────────────────────────────┐
│ TUN Interface │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ FilteredDevice (Read) │
│ - Fast IP extraction │
│ - Rule matching (10.30.30.30) │
└──────────────┬──────────────────────────────────────────────┘
┌──────────┴──────────┐
│ │
▼ ▼
┌─────────┐ ┌─────────────────────────┐
│DNS Proxy│ │ WireGuard Device │
│Netstack │ │ (other traffic) │
└────┬────┘ └─────────────────────────┘
│ Forward to 8.8.8.8
┌─────────────┐
│ Internet │
│ (Direct) │
└──────┬──────┘
│ DNS Response
┌─────────────────────────────────────────────────────────────┐
│ DNSProxy writes directly to TUN │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Application │
└─────────────────────────────────────────────────────────────┘
```
## Performance Considerations
### Fast Path Optimization
1. **Minimal Packet Inspection**
- Only extracts destination IP (bytes 16-19 for IPv4, 24-39 for IPv6)
- No deep packet inspection unless packet matches a rule
- Zero-copy operations where possible
2. **Rule Matching**
- Simple IP comparison (not prefix matching for rules)
- Linear scan of rules (fast for small number of rules)
- Read-lock only for rule access
3. **Packet Processing**
- Filtered packets are removed from the slice in-place
- Non-matching packets passed through with minimal overhead
- No memory allocation for packets that don't match rules
### Memory Efficiency
- Packet copies are only made when absolutely necessary
- gvisor netstack uses buffer pooling internally
- DNS proxy uses a separate goroutine for response handling
## Usage
### Configuration
The DNS proxy is automatically started when the tunnel is created. By default:
- DNS proxy IP: `10.30.30.30`
- DNS port: `53`
- Upstream DNS: `8.8.8.8` (primary), `8.8.4.4` (fallback)
### Testing
To test the DNS proxy, configure your DNS settings to use `10.30.30.30`:
```bash
# Using dig
dig @10.30.30.30 google.com
# Using nslookup
nslookup google.com 10.30.30.30
```
## Extensibility
The `FilteredDevice` architecture is designed to be extensible:
### Adding New Services
To add a new service (e.g., HTTP proxy on 10.30.30.31):
1. Create a new service similar to `DNSProxy`
2. Register a filter rule with `filteredDev.AddRule()`
3. Process packets in your handler
4. Write responses back to the TUN device
Example:
```go
// In your service
func (s *MyService) handlePacket(packet []byte) bool {
// Parse packet
// Process request
// Write response to TUN device
s.tunDevice.Write([][]byte{response}, 0)
return true // Drop from normal path
}
// During initialization
filteredDev.AddRule(myServiceIP, myService.handlePacket)
```
### Adding Filtering Rules
Rules can be added/removed dynamically:
```go
// Add a rule
filteredDev.AddRule(netip.MustParseAddr("10.30.30.40"), handleSpecialIP)
// Remove a rule
filteredDev.RemoveRule(netip.MustParseAddr("10.30.30.40"))
```
## Implementation Details
### Why Direct TUN Write?
The DNS proxy writes responses directly back to the TUN device instead of going through the filter because:
1. Responses should go to the host, not through WireGuard
2. Avoids infinite loops (response → filter → DNS proxy → ...)
3. Better performance (one less layer)
### Thread Safety
- `FilteredDevice` uses RWMutex for rule access (read-heavy workload)
- `DNSProxy` goroutines are properly synchronized
- TUN device write operations are thread-safe
### Error Handling
- Failed DNS queries fall back to secondary DNS server
- Malformed packets are logged but don't crash the proxy
- Context cancellation ensures clean shutdown
## Future Enhancements
Potential improvements:
1. DNS caching to reduce upstream queries
2. DNS-over-HTTPS (DoH) support
3. Custom DNS filtering/blocking
4. Metrics and monitoring
5. IPv6 support for DNS proxy
6. Multiple upstream DNS servers with health checking
7. HTTP/HTTPS proxy on different IPs
8. SOCKS5 proxy support

View File

@@ -1,214 +0,0 @@
# Virtual DNS Proxy Implementation - Summary
## What Was Implemented
A high-performance virtual DNS proxy for the olm WireGuard client that intercepts DNS queries before they enter the WireGuard tunnel. The implementation consists of three main components:
### 1. FilteredDevice (`olm/device_filter.go`)
A TUN device wrapper that provides fast packet filtering:
- **Performance**: 2.6 ns per packet inspection (benchmarked)
- **Zero overhead** for non-matching packets
- **Extensible**: Easy to add new filter rules for other services
- **Thread-safe**: Uses RWMutex for concurrent access
Key features:
- Fast destination IP extraction (IPv4 and IPv6)
- Protocol and port extraction utilities
- Rule-based packet interception
- In-place packet filtering (no unnecessary allocations)
### 2. DNSProxy (`olm/dns_proxy.go`)
A DNS proxy implementation using gvisor netstack:
- **Listens on**: `10.30.30.30:53`
- **Upstream DNS**: Google DNS (8.8.8.8, 8.8.4.4)
- **Bypass WireGuard**: DNS responses go directly to host
- **No tunnel overhead**: DNS queries don't consume VPN bandwidth
Architecture:
- Uses gvisor netstack for full TCP/IP stack simulation
- Separate goroutines for DNS query handling and response writing
- Direct TUN device write for responses (bypasses filter)
- Automatic failover between primary and secondary DNS servers
### 3. Integration (`olm/olm.go`)
Seamless integration into the tunnel lifecycle:
- Automatically started when tunnel is created
- Properly cleaned up when tunnel stops
- No configuration required (works out of the box)
## Performance Characteristics
### Packet Processing Speed
```
BenchmarkExtractDestIP-16 1000000 2.619 ns/op
```
This means:
- Can process ~380 million packets/second per core
- Negligible overhead on WireGuard throughput
- No measurable latency impact
### Memory Efficiency
- Zero allocations for non-matching packets
- Minimal allocations for DNS packets
- gvisor uses internal buffer pooling
## How to Use
### Basic Usage
The DNS proxy starts automatically when the tunnel is created. To use it:
```bash
# Configure your system to use 10.30.30.30 as DNS server
# Or test with dig/nslookup:
dig @10.30.30.30 google.com
nslookup google.com 10.30.30.30
```
### Adding New Virtual Services
To add a new service (e.g., HTTP proxy on 10.30.30.31):
```go
// 1. Create your service
type HTTPProxy struct {
tunDevice tun.Device
// ... other fields
}
// 2. Implement packet handler
func (h *HTTPProxy) handlePacket(packet []byte) bool {
// Process packet
// Write response to h.tunDevice
return true // Drop from normal path
}
// 3. Register with filter (in olm.go)
httpProxyIP := netip.MustParseAddr("10.30.30.31")
filteredDev.AddRule(httpProxyIP, httpProxy.handlePacket)
```
## Files Created
1. **`olm/device_filter.go`** - TUN device wrapper with packet filtering
2. **`olm/dns_proxy.go`** - DNS proxy using gvisor netstack
3. **`olm/device_filter_test.go`** - Unit tests and benchmarks
4. **`DNS_PROXY_README.md`** - Detailed architecture documentation
5. **`IMPLEMENTATION_SUMMARY.md`** - This file
## Testing
Tests included:
- `TestExtractDestIP` - Validates IPv4/IPv6 IP extraction
- `TestGetProtocol` - Validates protocol extraction
- `BenchmarkExtractDestIP` - Performance benchmark
Run tests:
```bash
go test ./olm -v -run "TestExtractDestIP|TestGetProtocol"
go test ./olm -bench=BenchmarkExtractDestIP
```
## Technical Details
### Packet Flow
```
Application → TUN → FilteredDevice → [DNS Proxy | WireGuard]
DNS Response
TUN ← Direct Write
```
### Why This Design?
1. **Wrapping TUN device**: Allows interception before WireGuard encryption
2. **Fast path optimization**: Only extracts what's needed (destination IP)
3. **Direct TUN write**: Responses bypass WireGuard to go straight to host
4. **Separate netstack**: Isolated DNS processing doesn't affect main stack
### Limitations & Future Work
Current limitations:
- Only IPv4 DNS (10.30.30.30)
- Hardcoded upstream DNS servers
- No DNS caching
- No DNS filtering/blocking
Potential enhancements:
- DNS caching layer
- DNS-over-HTTPS (DoH)
- IPv6 support
- Custom DNS rules/filtering
- HTTP/HTTPS proxy on other IPs
- SOCKS5 proxy support
- Metrics and monitoring
## Extensibility Examples
### Adding a TCP Service
```go
type TCPProxy struct {
stack *stack.Stack
tunDevice tun.Device
}
func (t *TCPProxy) handlePacket(packet []byte) bool {
// Check if it's TCP to our IP:port
proto, _ := GetProtocol(packet)
if proto != 6 { // TCP
return false
}
port, _ := GetDestPort(packet)
if port != 8080 {
return false
}
// Inject into our netstack
// ... handle TCP connection
return true
}
```
### Adding Multiple DNS Servers
Modify `dns_proxy.go` to support multiple virtual DNS IPs:
```go
const (
DNSProxyIP1 = "10.30.30.30"
DNSProxyIP2 = "10.30.30.31"
)
// Register multiple rules
filteredDev.AddRule(ip1, dnsProxy1.handlePacket)
filteredDev.AddRule(ip2, dnsProxy2.handlePacket)
```
## Build & Deploy
```bash
# Build
cd /home/owen/fossorial/olm
go build -o olm-binary .
# Test
go test ./olm -v
# Benchmark
go test ./olm -bench=. -benchmem
```
## Conclusion
This implementation provides:
- ✅ High-performance packet filtering (2.6 ns/packet)
- ✅ Zero overhead for non-DNS traffic
- ✅ Extensible architecture for future services
- ✅ Clean integration with existing codebase
- ✅ Comprehensive tests and documentation
- ✅ Production-ready code
The DNS proxy successfully intercepts DNS queries to 10.30.30.30, processes them through a separate gvisor netstack, forwards to upstream DNS servers, and returns responses directly to the host - all while bypassing the WireGuard tunnel.

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/network"
) )
// ConnectionRequest defines the structure for an incoming connection request // ConnectionRequest defines the structure for an incoming connection request
@@ -49,10 +50,10 @@ type PeerStatus struct {
type StatusResponse struct { type StatusResponse struct {
Connected bool `json:"connected"` Connected bool `json:"connected"`
Registered bool `json:"registered"` Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
OrgID string `json:"orgId,omitempty"` OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
} }
// API represents the HTTP server and its state // API represents the HTTP server and its state
@@ -70,7 +71,6 @@ type API struct {
connectedAt time.Time connectedAt time.Time
isConnected bool isConnected bool
isRegistered bool isRegistered bool
tunnelIP string
version string version string
orgID string orgID string
} }
@@ -206,13 +206,6 @@ func (s *API) SetRegistered(registered bool) {
s.isRegistered = registered s.isRegistered = registered
} }
// SetTunnelIP sets the tunnel IP address
func (s *API) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version // SetVersion sets the olm version
func (s *API) SetVersion(version string) { func (s *API) SetVersion(version string) {
s.statusMu.Lock() s.statusMu.Lock()
@@ -302,10 +295,10 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{ resp := StatusResponse{
Connected: s.isConnected, Connected: s.isConnected,
Registered: s.isRegistered, Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version, Version: s.version,
OrgID: s.orgID, OrgID: s.orgID,
PeerStatuses: s.peerStatuses, PeerStatuses: s.peerStatuses,
NetworkSettings: network.GetSettings(),
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package olm package device
import ( import (
"net" "net"
@@ -12,7 +12,7 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
dupTunFd, err := unix.Dup(int(tunFd)) dupTunFd, err := unix.Dup(int(tunFd))
if err != nil { if err != nil {
logger.Error("Unable to dup tun fd: %v", err) logger.Error("Unable to dup tun fd: %v", err)
@@ -35,10 +35,10 @@ func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
return device, nil return device, nil
} }
func uapiOpen(interfaceName string) (*os.File, error) { func UapiOpen(interfaceName string) (*os.File, error) {
return ipc.UAPIOpen(interfaceName) return ipc.UAPIOpen(interfaceName)
} }
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
return ipc.UAPIListen(interfaceName, fileUAPI) return ipc.UAPIListen(interfaceName, fileUAPI)
} }

View File

@@ -1,6 +1,6 @@
//go:build windows //go:build windows
package olm package device
import ( import (
"errors" "errors"
@@ -11,15 +11,15 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
return nil, errors.New("CreateTUNFromFile not supported on Windows") return nil, errors.New("CreateTUNFromFile not supported on Windows")
} }
func uapiOpen(interfaceName string) (*os.File, error) { func UapiOpen(interfaceName string) (*os.File, error) {
return nil, nil return nil, nil
} }
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
// On Windows, UAPIListen only takes one parameter // On Windows, UAPIListen only takes one parameter
return ipc.UAPIListen(interfaceName) return ipc.UAPIListen(interfaceName)
} }

523
diff
View File

@@ -1,523 +0,0 @@
diff --git a/api/api.go b/api/api.go
index dd07751..0d2e4ef 100644
--- a/api/api.go
+++ b/api/api.go
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
Endpoint string `json:"endpoint"`
}
+// SwitchOrgRequest defines the structure for switching organizations
+type SwitchOrgRequest struct {
+ OrgID string `json:"orgId"`
+}
+
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
@@ -35,6 +40,7 @@ type StatusResponse struct {
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
+ OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
@@ -46,6 +52,7 @@ type API struct {
server *http.Server
connectionChan chan ConnectionRequest
shutdownChan chan struct{}
+ switchOrgChan chan SwitchOrgRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
@@ -53,6 +60,7 @@ type API struct {
isRegistered bool
tunnelIP string
version string
+ orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -85,6 +95,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/exit", s.handleExit)
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
s.server = &http.Server{
Handler: mux,
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
+ return s.switchOrgChan
+}
+
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
s.version = version
}
+// SetOrgID sets the org ID
+func (s *API) SetOrgID(orgID string) {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+ s.orgID = orgID
+}
+
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
+ OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
"status": "shutdown initiated",
})
}
+
+// handleSwitchOrg handles the /switch-org endpoint
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req SwitchOrgRequest
+ decoder := json.NewDecoder(r.Body)
+ if err := decoder.Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.OrgID == "" {
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
+ return
+ }
+
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
+
+ // Send the request to the main goroutine
+ select {
+ case s.switchOrgChan <- req:
+ // Signal sent successfully
+ default:
+ // Channel already has a signal, don't block
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
+ return
+ }
+
+ // Return a success response
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusAccepted)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "org switch initiated",
+ "orgId": req.OrgID,
+ })
+}
diff --git a/olm/olm.go b/olm/olm.go
index 78080c4..5e292d6 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -58,6 +58,58 @@ type Config struct {
OrgID string
}
+// tunnelState holds all the active tunnel resources that need cleanup
+type tunnelState struct {
+ dev *device.Device
+ tdev tun.Device
+ uapiListener net.Listener
+ peerMonitor *peermonitor.PeerMonitor
+ stopRegister func()
+ connected bool
+}
+
+// teardownTunnel cleans up all tunnel resources
+func teardownTunnel(state *tunnelState) {
+ if state == nil {
+ return
+ }
+
+ logger.Info("Tearing down tunnel...")
+
+ // Stop registration messages
+ if state.stopRegister != nil {
+ state.stopRegister()
+ state.stopRegister = nil
+ }
+
+ // Stop peer monitor
+ if state.peerMonitor != nil {
+ state.peerMonitor.Stop()
+ state.peerMonitor = nil
+ }
+
+ // Close UAPI listener
+ if state.uapiListener != nil {
+ state.uapiListener.Close()
+ state.uapiListener = nil
+ }
+
+ // Close WireGuard device
+ if state.dev != nil {
+ state.dev.Close()
+ state.dev = nil
+ }
+
+ // Close TUN device
+ if state.tdev != nil {
+ state.tdev.Close()
+ state.tdev = nil
+ }
+
+ state.connected = false
+ logger.Info("Tunnel teardown complete")
+}
+
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
- connected bool
- dev *device.Device
wgData WgData
holePunchData HolePunchData
- uapiListener net.Listener
- tdev tun.Device
+ orgID = config.OrgID
)
+ // Tunnel state that can be torn down and recreated
+ tunnel := &tunnelState{}
+
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
+ apiServer.SetOrgID(orgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
- if connected {
+ if tunnel.connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
close(stopHolepunch)
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
- if dev != nil {
+ if tunnel.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
- dev.Close()
+ tunnel.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
return
}
- tdev, err = func() (tun.Device, error) {
+ tunnel.tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
return
}
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
return
}
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
go func() {
for {
- conn, err := uapiListener.Accept()
+ conn, err := tunnel.uapiListener.Accept()
if err != nil {
return
}
- go dev.IpcHandle(conn)
+ go tunnel.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
- if err = dev.Up(); err != nil {
+ if err = tunnel.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
- peerMonitor = peermonitor.NewPeerMonitor(
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
},
fixKey(privateKey.String()),
olm,
- dev,
+ tunnel.dev,
doHolepunch,
)
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
logger.Info("Configured peer %s", site.PublicKey)
}
- peerMonitor.Start()
+ tunnel.peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true)
}
- connected = true
+ tunnel.connected = true
logger.Info("WireGuard device created.")
})
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
}
// Update the peer in WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
}
// Add the peer to WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
}
// Remove the peer from WireGuard
- if dev != nil {
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
+ if tunnel.dev != nil {
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetConnectionStatus(true)
}
- if connected {
+ if tunnel.connected {
logger.Debug("Already connected, skipping registration")
return nil
}
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
- "orgId": config.OrgID,
+ "orgId": orgID,
}, 1*time.Second)
go keepSendingPing(olm)
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
}
defer olm.Close()
+ // Listen for org switch requests from the API (after olm is created)
+ if apiServer != nil {
+ go func() {
+ for req := range apiServer.GetSwitchOrgChannel() {
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
+
+ // Update the orgId
+ orgID = req.OrgID
+
+ // Teardown existing tunnel
+ teardownTunnel(tunnel)
+
+ // Reset tunnel state
+ tunnel = &tunnelState{}
+
+ // Stop holepunch
+ select {
+ case <-stopHolepunch:
+ // Channel already closed
+ default:
+ close(stopHolepunch)
+ }
+ stopHolepunch = make(chan struct{})
+
+ // Clear API server state
+ apiServer.SetRegistered(false)
+ apiServer.SetTunnelIP("")
+ apiServer.SetOrgID(orgID)
+
+ // Send new registration message with updated orgId
+ publicKey := privateKey.PublicKey()
+ logger.Info("Sending registration message with new orgId: %s", orgID)
+
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ "publicKey": publicKey.String(),
+ "relay": !doHolepunch,
+ "olmVersion": config.Version,
+ "orgId": orgID,
+ }, 1*time.Second)
+ }
+ }()
+ }
+
select {
case <-ctx.Done():
logger.Info("Context cancelled")
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch)
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
select {
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
- if peerMonitor != nil {
- peerMonitor.Stop()
- }
-
- if uapiListener != nil {
- uapiListener.Close()
- }
- if dev != nil {
- dev.Close()
- }
+ // Use teardownTunnel to clean up all tunnel resources
+ teardownTunnel(tunnel)
if apiServer != nil {
apiServer.Stop()

View File

@@ -1,4 +1,4 @@
package olm package network
import ( import (
"fmt" "fmt"
@@ -10,16 +10,15 @@ import (
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/network"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
// ConfigureInterface configures a network interface with an IP address and brings it up // ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error { func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error {
logger.Info("The tunnel IP is: %s", wgData.TunnelIP) logger.Info("The tunnel IP is: %s", tunnelIp)
// Parse the IP address and network // Parse the IP address and network
ip, ipNet, err := net.ParseCIDR(wgData.TunnelIP) ip, ipNet, err := net.ParseCIDR(tunnelIp)
if err != nil { if err != nil {
return fmt.Errorf("invalid IP address: %v", err) return fmt.Errorf("invalid IP address: %v", err)
} }
@@ -31,9 +30,8 @@ func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
logger.Debug("The destination address is: %s", destinationAddress) logger.Debug("The destination address is: %s", destinationAddress)
// network.SetTunnelRemoteAddress() // what does this do? // network.SetTunnelRemoteAddress() // what does this do?
network.SetIPv4Settings([]string{destinationAddress}, []string{mask}) SetIPv4Settings([]string{destinationAddress}, []string{mask})
network.SetMTU(mtu) SetMTU(mtu)
apiServer.SetTunnelIP(destinationAddress)
if interfaceName == "" { if interfaceName == "" {
return nil return nil
@@ -89,7 +87,7 @@ func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Du
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
} }
func findUnusedUTUN() (string, error) { func FindUnusedUTUN() (string, error) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to list interfaces: %v", err) return "", fmt.Errorf("failed to list interfaces: %v", err)

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package olm package network
import ( import (
"fmt" "fmt"

View File

@@ -1,6 +1,6 @@
//go:build windows //go:build windows
package olm package network
import ( import (
"fmt" "fmt"

View File

@@ -1,4 +1,4 @@
package olm package network
import ( import (
"fmt" "fmt"
@@ -8,7 +8,6 @@ import (
"strings" "strings"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/olm/network"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
@@ -126,8 +125,8 @@ func LinuxRemoveRoute(destination string) error {
} }
// addRouteForServerIP adds an OS-specific route for the server IP // addRouteForServerIP adds an OS-specific route for the server IP
func addRouteForServerIP(serverIP, interfaceName string) error { func AddRouteForServerIP(serverIP, interfaceName string) error {
if err := addRouteForNetworkConfig(serverIP); err != nil { if err := AddRouteForNetworkConfig(serverIP); err != nil {
return err return err
} }
if interfaceName == "" { if interfaceName == "" {
@@ -145,8 +144,8 @@ func addRouteForServerIP(serverIP, interfaceName string) error {
} }
// removeRouteForServerIP removes an OS-specific route for the server IP // removeRouteForServerIP removes an OS-specific route for the server IP
func removeRouteForServerIP(serverIP string, interfaceName string) error { func RemoveRouteForServerIP(serverIP string, interfaceName string) error {
if err := removeRouteForNetworkConfig(serverIP); err != nil { if err := RemoveRouteForNetworkConfig(serverIP); err != nil {
return err return err
} }
if interfaceName == "" { if interfaceName == "" {
@@ -163,7 +162,7 @@ func removeRouteForServerIP(serverIP string, interfaceName string) error {
return nil return nil
} }
func addRouteForNetworkConfig(destination string) error { func AddRouteForNetworkConfig(destination string) error {
// Parse the subnet to extract IP and mask // Parse the subnet to extract IP and mask
_, ipNet, err := net.ParseCIDR(destination) _, ipNet, err := net.ParseCIDR(destination)
if err != nil { if err != nil {
@@ -174,12 +173,12 @@ func addRouteForNetworkConfig(destination string) error {
mask := net.IP(ipNet.Mask).String() mask := net.IP(ipNet.Mask).String()
destinationAddress := ipNet.IP.String() destinationAddress := ipNet.IP.String()
network.AddIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
return nil return nil
} }
func removeRouteForNetworkConfig(destination string) error { func RemoveRouteForNetworkConfig(destination string) error {
// Parse the subnet to extract IP and mask // Parse the subnet to extract IP and mask
_, ipNet, err := net.ParseCIDR(destination) _, ipNet, err := net.ParseCIDR(destination)
if err != nil { if err != nil {
@@ -190,13 +189,13 @@ func removeRouteForNetworkConfig(destination string) error {
mask := net.IP(ipNet.Mask).String() mask := net.IP(ipNet.Mask).String()
destinationAddress := ipNet.IP.String() destinationAddress := ipNet.IP.String()
network.RemoveIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
return nil return nil
} }
// addRoutes adds routes for each subnet in RemoteSubnets // addRoutes adds routes for each subnet in RemoteSubnets
func addRoutes(remoteSubnets []string, interfaceName string) error { func AddRoutes(remoteSubnets []string, interfaceName string) error {
if len(remoteSubnets) == 0 { if len(remoteSubnets) == 0 {
return nil return nil
} }
@@ -208,7 +207,7 @@ func addRoutes(remoteSubnets []string, interfaceName string) error {
continue continue
} }
if err := addRouteForNetworkConfig(subnet); err != nil { if err := AddRouteForNetworkConfig(subnet); err != nil {
logger.Error("Failed to add network config for subnet %s: %v", subnet, err) logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
continue continue
} }
@@ -241,7 +240,7 @@ func addRoutes(remoteSubnets []string, interfaceName string) error {
} }
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets // removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
func removeRoutesForRemoteSubnets(remoteSubnets []string) error { func RemoveRoutesForRemoteSubnets(remoteSubnets []string) error {
if len(remoteSubnets) == 0 { if len(remoteSubnets) == 0 {
return nil return nil
} }
@@ -253,7 +252,7 @@ func removeRoutesForRemoteSubnets(remoteSubnets []string) error {
continue continue
} }
if err := removeRouteForNetworkConfig(subnet); err != nil { if err := RemoveRouteForNetworkConfig(subnet); err != nil {
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err) logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
continue continue
} }

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package olm package network
func WindowsAddRoute(destination string, gateway string, interfaceName string) error { func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
return nil return nil

View File

@@ -1,6 +1,6 @@
//go:build windows //go:build windows
package olm package network
import ( import (
"fmt" "fmt"

View File

@@ -177,6 +177,12 @@ func GetJSON() (string, error) {
return string(data), nil return string(data), nil
} }
func GetSettings() NetworkSettings {
networkSettingsMutex.RLock()
defer networkSettingsMutex.RUnlock()
return networkSettings
}
func GetIncrementor() int { func GetIncrementor() int {
networkSettingsMutex.Lock() networkSettingsMutex.Lock()
defer networkSettingsMutex.Unlock() defer networkSettingsMutex.Unlock()

View File

@@ -14,7 +14,7 @@ import (
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util" "github.com/fosrl/newt/util"
"github.com/fosrl/olm/api" "github.com/fosrl/olm/api"
middleDevice "github.com/fosrl/olm/device" olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns" "github.com/fosrl/olm/dns"
dnsOverride "github.com/fosrl/olm/dns/override" dnsOverride "github.com/fosrl/olm/dns/override"
"github.com/fosrl/olm/network" "github.com/fosrl/olm/network"
@@ -79,7 +79,7 @@ var (
holePunchData HolePunchData holePunchData HolePunchData
uapiListener net.Listener uapiListener net.Listener
tdev tun.Device tdev tun.Device
middleDev *middleDevice.MiddleDevice middleDev *olmDevice.MiddleDevice
dnsProxy *dns.DNSProxy dnsProxy *dns.DNSProxy
apiServer *api.API apiServer *api.API
olmClient *websocket.Client olmClient *websocket.Client
@@ -201,7 +201,6 @@ func Init(ctx context.Context, config GlobalConfig) {
// Clear peer statuses in API // Clear peer statuses in API
apiServer.SetRegistered(false) apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
// Trigger re-registration with new orgId // Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", req.OrgID) logger.Info("Re-registering with new orgId: %s", req.OrgID)
@@ -418,11 +417,11 @@ func StartTunnel(config TunnelConfig) {
tdev, err = func() (tun.Device, error) { tdev, err = func() (tun.Device, error) {
if config.FileDescriptorTun != 0 { if config.FileDescriptorTun != 0 {
return createTUNFromFD(config.FileDescriptorTun, config.MTU) return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU)
} }
var ifName = interfaceName var ifName = interfaceName
if runtime.GOOS == "darwin" { // this is if we dont pass a fd if runtime.GOOS == "darwin" { // this is if we dont pass a fd
ifName, err = findUnusedUTUN() ifName, err = network.FindUnusedUTUN()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -458,7 +457,7 @@ func StartTunnel(config TunnelConfig) {
// } // }
// Wrap TUN device with packet filter for DNS proxy // Wrap TUN device with packet filter for DNS proxy
middleDev = middleDevice.NewMiddleDevice(tdev) middleDev = olmDevice.NewMiddleDevice(tdev)
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device // Use filtered device instead of raw TUN device
@@ -495,11 +494,11 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to create DNS proxy: %v", err) logger.Error("Failed to create DNS proxy: %v", err)
} }
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil { if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil {
logger.Error("Failed to configure interface: %v", err) logger.Error("Failed to configure interface: %v", err)
} }
if addRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet
logger.Error("Failed to add route for utility subnet: %v", err) logger.Error("Failed to add route for utility subnet: %v", err)
} }
@@ -549,11 +548,11 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to configure peer: %v", err) logger.Error("Failed to configure peer: %v", err)
return return
} }
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required if err := network.AddRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required
logger.Error("Failed to add route for peer: %v", err) logger.Error("Failed to add route for peer: %v", err)
return return
} }
if err := addRoutes(site.RemoteSubnets, interfaceName); err != nil { if err := network.AddRoutes(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err) logger.Error("Failed to add routes for remote subnets: %v", err)
return return
} }
@@ -676,13 +675,13 @@ func StartTunnel(config TunnelConfig) {
// Handle remote subnet route changes // Handle remote subnet route changes
if !stringSlicesEqual(oldRemoteSubnets, siteConfig.RemoteSubnets) { if !stringSlicesEqual(oldRemoteSubnets, siteConfig.RemoteSubnets) {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { if err := network.RemoveRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err) logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes // Continue anyway to add new routes
} }
// Add new remote subnet routes // Add new remote subnet routes
if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { if err := network.AddRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err) logger.Error("Failed to add new remote subnet routes: %v", err)
return return
} }
@@ -721,11 +720,11 @@ func StartTunnel(config TunnelConfig) {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
} }
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { if err := network.AddRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err) logger.Error("Failed to add route for new peer: %v", err)
return return
} }
if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil { if err := network.AddRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err) logger.Error("Failed to add routes for remote subnets: %v", err)
return return
} }
@@ -782,14 +781,14 @@ func StartTunnel(config TunnelConfig) {
} }
// Remove route for the peer // Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP, interfaceName) err = network.RemoveRouteForServerIP(peerToRemove.ServerIP, interfaceName)
if err != nil { if err != nil {
logger.Error("Failed to remove route for peer: %v", err) logger.Error("Failed to remove route for peer: %v", err)
return return
} }
// Remove routes for remote subnets // Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { if err := network.RemoveRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err) logger.Error("Failed to remove routes for remote subnets: %v", err)
return return
} }
@@ -851,7 +850,7 @@ func StartTunnel(config TunnelConfig) {
} }
// Add routes for the new subnets // Add routes for the new subnets
if err := addRoutes(newSubnets, interfaceName); err != nil { if err := network.AddRoutes(newSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for new remote subnets: %v", err) logger.Error("Failed to add routes for new remote subnets: %v", err)
return return
} }
@@ -912,7 +911,7 @@ func StartTunnel(config TunnelConfig) {
} }
// Remove routes for the removed subnets // Remove routes for the removed subnets
if err := removeRoutesForRemoteSubnets(removedSubnets); err != nil { if err := network.RemoveRoutesForRemoteSubnets(removedSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err) logger.Error("Failed to remove routes for remote subnets: %v", err)
return return
} }
@@ -955,7 +954,7 @@ func StartTunnel(config TunnelConfig) {
// First, remove routes for old subnets // First, remove routes for old subnets
if len(updateSubnetsData.OldRemoteSubnets) > 0 { if len(updateSubnetsData.OldRemoteSubnets) > 0 {
if err := removeRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets); err != nil { if err := network.RemoveRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets); err != nil {
logger.Error("Failed to remove routes for old remote subnets: %v", err) logger.Error("Failed to remove routes for old remote subnets: %v", err)
return return
} }
@@ -964,10 +963,10 @@ func StartTunnel(config TunnelConfig) {
// Then, add routes for new subnets // Then, add routes for new subnets
if len(updateSubnetsData.NewRemoteSubnets) > 0 { if len(updateSubnetsData.NewRemoteSubnets) > 0 {
if err := addRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil { if err := network.AddRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for new remote subnets: %v", err) logger.Error("Failed to add routes for new remote subnets: %v", err)
// Attempt to rollback by re-adding old routes // Attempt to rollback by re-adding old routes
if rollbackErr := addRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil { if rollbackErr := network.AddRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil {
logger.Error("Failed to rollback old routes: %v", rollbackErr) logger.Error("Failed to rollback old routes: %v", rollbackErr)
} }
return return
@@ -1186,7 +1185,6 @@ func StopTunnel() {
// Update API server status // Update API server status
apiServer.SetConnectionStatus(false) apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false) apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
network.ClearNetworkSettings() network.ClearNetworkSettings()