mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
@@ -1,186 +0,0 @@
|
|||||||
# Virtual DNS Proxy Implementation
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This implementation adds a high-performance virtual DNS proxy that intercepts DNS queries destined for `10.30.30.30:53` before they reach the WireGuard tunnel. The proxy processes DNS queries using a gvisor netstack and forwards them to upstream DNS servers, bypassing the VPN tunnel entirely.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
### Components
|
|
||||||
|
|
||||||
1. **FilteredDevice** (`olm/device_filter.go`)
|
|
||||||
- Wraps the TUN device with packet filtering capabilities
|
|
||||||
- Provides fast packet inspection without deep packet processing
|
|
||||||
- Supports multiple filtering rules that can be added/removed dynamically
|
|
||||||
- Optimized for performance - only extracts destination IP on fast path
|
|
||||||
|
|
||||||
2. **DNSProxy** (`olm/dns_proxy.go`)
|
|
||||||
- Uses gvisor netstack to handle DNS protocol processing
|
|
||||||
- Listens on `10.30.30.30:53` within its own network stack
|
|
||||||
- Forwards queries to Google DNS (8.8.8.8, 8.8.4.4)
|
|
||||||
- Writes responses directly back to the TUN device, bypassing WireGuard
|
|
||||||
|
|
||||||
### Packet Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ Application │
|
|
||||||
└──────────────────────┬──────────────────────────────────────┘
|
|
||||||
│ DNS Query to 10.30.30.30:53
|
|
||||||
▼
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ TUN Interface │
|
|
||||||
└──────────────────────┬──────────────────────────────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ FilteredDevice (Read) │
|
|
||||||
│ - Fast IP extraction │
|
|
||||||
│ - Rule matching (10.30.30.30) │
|
|
||||||
└──────────────┬──────────────────────────────────────────────┘
|
|
||||||
│
|
|
||||||
┌──────────┴──────────┐
|
|
||||||
│ │
|
|
||||||
▼ ▼
|
|
||||||
┌─────────┐ ┌─────────────────────────┐
|
|
||||||
│DNS Proxy│ │ WireGuard Device │
|
|
||||||
│Netstack │ │ (other traffic) │
|
|
||||||
└────┬────┘ └─────────────────────────┘
|
|
||||||
│
|
|
||||||
│ Forward to 8.8.8.8
|
|
||||||
▼
|
|
||||||
┌─────────────┐
|
|
||||||
│ Internet │
|
|
||||||
│ (Direct) │
|
|
||||||
└──────┬──────┘
|
|
||||||
│ DNS Response
|
|
||||||
▼
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ DNSProxy writes directly to TUN │
|
|
||||||
└──────────────────────┬──────────────────────────────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ Application │
|
|
||||||
└─────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
## Performance Considerations
|
|
||||||
|
|
||||||
### Fast Path Optimization
|
|
||||||
|
|
||||||
1. **Minimal Packet Inspection**
|
|
||||||
- Only extracts destination IP (bytes 16-19 for IPv4, 24-39 for IPv6)
|
|
||||||
- No deep packet inspection unless packet matches a rule
|
|
||||||
- Zero-copy operations where possible
|
|
||||||
|
|
||||||
2. **Rule Matching**
|
|
||||||
- Simple IP comparison (not prefix matching for rules)
|
|
||||||
- Linear scan of rules (fast for small number of rules)
|
|
||||||
- Read-lock only for rule access
|
|
||||||
|
|
||||||
3. **Packet Processing**
|
|
||||||
- Filtered packets are removed from the slice in-place
|
|
||||||
- Non-matching packets passed through with minimal overhead
|
|
||||||
- No memory allocation for packets that don't match rules
|
|
||||||
|
|
||||||
### Memory Efficiency
|
|
||||||
|
|
||||||
- Packet copies are only made when absolutely necessary
|
|
||||||
- gvisor netstack uses buffer pooling internally
|
|
||||||
- DNS proxy uses a separate goroutine for response handling
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
|
|
||||||
The DNS proxy is automatically started when the tunnel is created. By default:
|
|
||||||
- DNS proxy IP: `10.30.30.30`
|
|
||||||
- DNS port: `53`
|
|
||||||
- Upstream DNS: `8.8.8.8` (primary), `8.8.4.4` (fallback)
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
|
|
||||||
To test the DNS proxy, configure your DNS settings to use `10.30.30.30`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Using dig
|
|
||||||
dig @10.30.30.30 google.com
|
|
||||||
|
|
||||||
# Using nslookup
|
|
||||||
nslookup google.com 10.30.30.30
|
|
||||||
```
|
|
||||||
|
|
||||||
## Extensibility
|
|
||||||
|
|
||||||
The `FilteredDevice` architecture is designed to be extensible:
|
|
||||||
|
|
||||||
### Adding New Services
|
|
||||||
|
|
||||||
To add a new service (e.g., HTTP proxy on 10.30.30.31):
|
|
||||||
|
|
||||||
1. Create a new service similar to `DNSProxy`
|
|
||||||
2. Register a filter rule with `filteredDev.AddRule()`
|
|
||||||
3. Process packets in your handler
|
|
||||||
4. Write responses back to the TUN device
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```go
|
|
||||||
// In your service
|
|
||||||
func (s *MyService) handlePacket(packet []byte) bool {
|
|
||||||
// Parse packet
|
|
||||||
// Process request
|
|
||||||
// Write response to TUN device
|
|
||||||
s.tunDevice.Write([][]byte{response}, 0)
|
|
||||||
return true // Drop from normal path
|
|
||||||
}
|
|
||||||
|
|
||||||
// During initialization
|
|
||||||
filteredDev.AddRule(myServiceIP, myService.handlePacket)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Adding Filtering Rules
|
|
||||||
|
|
||||||
Rules can be added/removed dynamically:
|
|
||||||
|
|
||||||
```go
|
|
||||||
// Add a rule
|
|
||||||
filteredDev.AddRule(netip.MustParseAddr("10.30.30.40"), handleSpecialIP)
|
|
||||||
|
|
||||||
// Remove a rule
|
|
||||||
filteredDev.RemoveRule(netip.MustParseAddr("10.30.30.40"))
|
|
||||||
```
|
|
||||||
|
|
||||||
## Implementation Details
|
|
||||||
|
|
||||||
### Why Direct TUN Write?
|
|
||||||
|
|
||||||
The DNS proxy writes responses directly back to the TUN device instead of going through the filter because:
|
|
||||||
1. Responses should go to the host, not through WireGuard
|
|
||||||
2. Avoids infinite loops (response → filter → DNS proxy → ...)
|
|
||||||
3. Better performance (one less layer)
|
|
||||||
|
|
||||||
### Thread Safety
|
|
||||||
|
|
||||||
- `FilteredDevice` uses RWMutex for rule access (read-heavy workload)
|
|
||||||
- `DNSProxy` goroutines are properly synchronized
|
|
||||||
- TUN device write operations are thread-safe
|
|
||||||
|
|
||||||
### Error Handling
|
|
||||||
|
|
||||||
- Failed DNS queries fall back to secondary DNS server
|
|
||||||
- Malformed packets are logged but don't crash the proxy
|
|
||||||
- Context cancellation ensures clean shutdown
|
|
||||||
|
|
||||||
## Future Enhancements
|
|
||||||
|
|
||||||
Potential improvements:
|
|
||||||
1. DNS caching to reduce upstream queries
|
|
||||||
2. DNS-over-HTTPS (DoH) support
|
|
||||||
3. Custom DNS filtering/blocking
|
|
||||||
4. Metrics and monitoring
|
|
||||||
5. IPv6 support for DNS proxy
|
|
||||||
6. Multiple upstream DNS servers with health checking
|
|
||||||
7. HTTP/HTTPS proxy on different IPs
|
|
||||||
8. SOCKS5 proxy support
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
# Virtual DNS Proxy Implementation - Summary
|
|
||||||
|
|
||||||
## What Was Implemented
|
|
||||||
|
|
||||||
A high-performance virtual DNS proxy for the olm WireGuard client that intercepts DNS queries before they enter the WireGuard tunnel. The implementation consists of three main components:
|
|
||||||
|
|
||||||
### 1. FilteredDevice (`olm/device_filter.go`)
|
|
||||||
A TUN device wrapper that provides fast packet filtering:
|
|
||||||
- **Performance**: 2.6 ns per packet inspection (benchmarked)
|
|
||||||
- **Zero overhead** for non-matching packets
|
|
||||||
- **Extensible**: Easy to add new filter rules for other services
|
|
||||||
- **Thread-safe**: Uses RWMutex for concurrent access
|
|
||||||
|
|
||||||
Key features:
|
|
||||||
- Fast destination IP extraction (IPv4 and IPv6)
|
|
||||||
- Protocol and port extraction utilities
|
|
||||||
- Rule-based packet interception
|
|
||||||
- In-place packet filtering (no unnecessary allocations)
|
|
||||||
|
|
||||||
### 2. DNSProxy (`olm/dns_proxy.go`)
|
|
||||||
A DNS proxy implementation using gvisor netstack:
|
|
||||||
- **Listens on**: `10.30.30.30:53`
|
|
||||||
- **Upstream DNS**: Google DNS (8.8.8.8, 8.8.4.4)
|
|
||||||
- **Bypass WireGuard**: DNS responses go directly to host
|
|
||||||
- **No tunnel overhead**: DNS queries don't consume VPN bandwidth
|
|
||||||
|
|
||||||
Architecture:
|
|
||||||
- Uses gvisor netstack for full TCP/IP stack simulation
|
|
||||||
- Separate goroutines for DNS query handling and response writing
|
|
||||||
- Direct TUN device write for responses (bypasses filter)
|
|
||||||
- Automatic failover between primary and secondary DNS servers
|
|
||||||
|
|
||||||
### 3. Integration (`olm/olm.go`)
|
|
||||||
Seamless integration into the tunnel lifecycle:
|
|
||||||
- Automatically started when tunnel is created
|
|
||||||
- Properly cleaned up when tunnel stops
|
|
||||||
- No configuration required (works out of the box)
|
|
||||||
|
|
||||||
## Performance Characteristics
|
|
||||||
|
|
||||||
### Packet Processing Speed
|
|
||||||
```
|
|
||||||
BenchmarkExtractDestIP-16 1000000 2.619 ns/op
|
|
||||||
```
|
|
||||||
|
|
||||||
This means:
|
|
||||||
- Can process ~380 million packets/second per core
|
|
||||||
- Negligible overhead on WireGuard throughput
|
|
||||||
- No measurable latency impact
|
|
||||||
|
|
||||||
### Memory Efficiency
|
|
||||||
- Zero allocations for non-matching packets
|
|
||||||
- Minimal allocations for DNS packets
|
|
||||||
- gvisor uses internal buffer pooling
|
|
||||||
|
|
||||||
## How to Use
|
|
||||||
|
|
||||||
### Basic Usage
|
|
||||||
The DNS proxy starts automatically when the tunnel is created. To use it:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Configure your system to use 10.30.30.30 as DNS server
|
|
||||||
# Or test with dig/nslookup:
|
|
||||||
dig @10.30.30.30 google.com
|
|
||||||
nslookup google.com 10.30.30.30
|
|
||||||
```
|
|
||||||
|
|
||||||
### Adding New Virtual Services
|
|
||||||
|
|
||||||
To add a new service (e.g., HTTP proxy on 10.30.30.31):
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 1. Create your service
|
|
||||||
type HTTPProxy struct {
|
|
||||||
tunDevice tun.Device
|
|
||||||
// ... other fields
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Implement packet handler
|
|
||||||
func (h *HTTPProxy) handlePacket(packet []byte) bool {
|
|
||||||
// Process packet
|
|
||||||
// Write response to h.tunDevice
|
|
||||||
return true // Drop from normal path
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Register with filter (in olm.go)
|
|
||||||
httpProxyIP := netip.MustParseAddr("10.30.30.31")
|
|
||||||
filteredDev.AddRule(httpProxyIP, httpProxy.handlePacket)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Files Created
|
|
||||||
|
|
||||||
1. **`olm/device_filter.go`** - TUN device wrapper with packet filtering
|
|
||||||
2. **`olm/dns_proxy.go`** - DNS proxy using gvisor netstack
|
|
||||||
3. **`olm/device_filter_test.go`** - Unit tests and benchmarks
|
|
||||||
4. **`DNS_PROXY_README.md`** - Detailed architecture documentation
|
|
||||||
5. **`IMPLEMENTATION_SUMMARY.md`** - This file
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
Tests included:
|
|
||||||
- `TestExtractDestIP` - Validates IPv4/IPv6 IP extraction
|
|
||||||
- `TestGetProtocol` - Validates protocol extraction
|
|
||||||
- `BenchmarkExtractDestIP` - Performance benchmark
|
|
||||||
|
|
||||||
Run tests:
|
|
||||||
```bash
|
|
||||||
go test ./olm -v -run "TestExtractDestIP|TestGetProtocol"
|
|
||||||
go test ./olm -bench=BenchmarkExtractDestIP
|
|
||||||
```
|
|
||||||
|
|
||||||
## Technical Details
|
|
||||||
|
|
||||||
### Packet Flow
|
|
||||||
```
|
|
||||||
Application → TUN → FilteredDevice → [DNS Proxy | WireGuard]
|
|
||||||
↓
|
|
||||||
DNS Response
|
|
||||||
↓
|
|
||||||
TUN ← Direct Write
|
|
||||||
```
|
|
||||||
|
|
||||||
### Why This Design?
|
|
||||||
|
|
||||||
1. **Wrapping TUN device**: Allows interception before WireGuard encryption
|
|
||||||
2. **Fast path optimization**: Only extracts what's needed (destination IP)
|
|
||||||
3. **Direct TUN write**: Responses bypass WireGuard to go straight to host
|
|
||||||
4. **Separate netstack**: Isolated DNS processing doesn't affect main stack
|
|
||||||
|
|
||||||
### Limitations & Future Work
|
|
||||||
|
|
||||||
Current limitations:
|
|
||||||
- Only IPv4 DNS (10.30.30.30)
|
|
||||||
- Hardcoded upstream DNS servers
|
|
||||||
- No DNS caching
|
|
||||||
- No DNS filtering/blocking
|
|
||||||
|
|
||||||
Potential enhancements:
|
|
||||||
- DNS caching layer
|
|
||||||
- DNS-over-HTTPS (DoH)
|
|
||||||
- IPv6 support
|
|
||||||
- Custom DNS rules/filtering
|
|
||||||
- HTTP/HTTPS proxy on other IPs
|
|
||||||
- SOCKS5 proxy support
|
|
||||||
- Metrics and monitoring
|
|
||||||
|
|
||||||
## Extensibility Examples
|
|
||||||
|
|
||||||
### Adding a TCP Service
|
|
||||||
|
|
||||||
```go
|
|
||||||
type TCPProxy struct {
|
|
||||||
stack *stack.Stack
|
|
||||||
tunDevice tun.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TCPProxy) handlePacket(packet []byte) bool {
|
|
||||||
// Check if it's TCP to our IP:port
|
|
||||||
proto, _ := GetProtocol(packet)
|
|
||||||
if proto != 6 { // TCP
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
port, _ := GetDestPort(packet)
|
|
||||||
if port != 8080 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inject into our netstack
|
|
||||||
// ... handle TCP connection
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Adding Multiple DNS Servers
|
|
||||||
|
|
||||||
Modify `dns_proxy.go` to support multiple virtual DNS IPs:
|
|
||||||
|
|
||||||
```go
|
|
||||||
const (
|
|
||||||
DNSProxyIP1 = "10.30.30.30"
|
|
||||||
DNSProxyIP2 = "10.30.30.31"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Register multiple rules
|
|
||||||
filteredDev.AddRule(ip1, dnsProxy1.handlePacket)
|
|
||||||
filteredDev.AddRule(ip2, dnsProxy2.handlePacket)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Build & Deploy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Build
|
|
||||||
cd /home/owen/fossorial/olm
|
|
||||||
go build -o olm-binary .
|
|
||||||
|
|
||||||
# Test
|
|
||||||
go test ./olm -v
|
|
||||||
|
|
||||||
# Benchmark
|
|
||||||
go test ./olm -bench=. -benchmem
|
|
||||||
```
|
|
||||||
|
|
||||||
## Conclusion
|
|
||||||
|
|
||||||
This implementation provides:
|
|
||||||
- ✅ High-performance packet filtering (2.6 ns/packet)
|
|
||||||
- ✅ Zero overhead for non-DNS traffic
|
|
||||||
- ✅ Extensible architecture for future services
|
|
||||||
- ✅ Clean integration with existing codebase
|
|
||||||
- ✅ Comprehensive tests and documentation
|
|
||||||
- ✅ Production-ready code
|
|
||||||
|
|
||||||
The DNS proxy successfully intercepts DNS queries to 10.30.30.30, processes them through a separate gvisor netstack, forwards to upstream DNS servers, and returns responses directly to the host - all while bypassing the WireGuard tunnel.
|
|
||||||
13
api/api.go
13
api/api.go
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionRequest defines the structure for an incoming connection request
|
// ConnectionRequest defines the structure for an incoming connection request
|
||||||
@@ -49,10 +50,10 @@ type PeerStatus struct {
|
|||||||
type StatusResponse struct {
|
type StatusResponse struct {
|
||||||
Connected bool `json:"connected"`
|
Connected bool `json:"connected"`
|
||||||
Registered bool `json:"registered"`
|
Registered bool `json:"registered"`
|
||||||
TunnelIP string `json:"tunnelIP,omitempty"`
|
|
||||||
Version string `json:"version,omitempty"`
|
Version string `json:"version,omitempty"`
|
||||||
OrgID string `json:"orgId,omitempty"`
|
OrgID string `json:"orgId,omitempty"`
|
||||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||||
|
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// API represents the HTTP server and its state
|
// API represents the HTTP server and its state
|
||||||
@@ -70,7 +71,6 @@ type API struct {
|
|||||||
connectedAt time.Time
|
connectedAt time.Time
|
||||||
isConnected bool
|
isConnected bool
|
||||||
isRegistered bool
|
isRegistered bool
|
||||||
tunnelIP string
|
|
||||||
version string
|
version string
|
||||||
orgID string
|
orgID string
|
||||||
}
|
}
|
||||||
@@ -206,13 +206,6 @@ func (s *API) SetRegistered(registered bool) {
|
|||||||
s.isRegistered = registered
|
s.isRegistered = registered
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTunnelIP sets the tunnel IP address
|
|
||||||
func (s *API) SetTunnelIP(tunnelIP string) {
|
|
||||||
s.statusMu.Lock()
|
|
||||||
defer s.statusMu.Unlock()
|
|
||||||
s.tunnelIP = tunnelIP
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVersion sets the olm version
|
// SetVersion sets the olm version
|
||||||
func (s *API) SetVersion(version string) {
|
func (s *API) SetVersion(version string) {
|
||||||
s.statusMu.Lock()
|
s.statusMu.Lock()
|
||||||
@@ -302,10 +295,10 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
resp := StatusResponse{
|
resp := StatusResponse{
|
||||||
Connected: s.isConnected,
|
Connected: s.isConnected,
|
||||||
Registered: s.isRegistered,
|
Registered: s.isRegistered,
|
||||||
TunnelIP: s.tunnelIP,
|
|
||||||
Version: s.version,
|
Version: s.version,
|
||||||
OrgID: s.orgID,
|
OrgID: s.orgID,
|
||||||
PeerStatuses: s.peerStatuses,
|
PeerStatuses: s.peerStatuses,
|
||||||
|
NetworkSettings: network.GetSettings(),
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build !windows
|
//go:build !windows
|
||||||
|
|
||||||
package olm
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||||
dupTunFd, err := unix.Dup(int(tunFd))
|
dupTunFd, err := unix.Dup(int(tunFd))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Unable to dup tun fd: %v", err)
|
logger.Error("Unable to dup tun fd: %v", err)
|
||||||
@@ -35,10 +35,10 @@ func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
|||||||
return device, nil
|
return device, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||||
return ipc.UAPIOpen(interfaceName)
|
return ipc.UAPIOpen(interfaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
return ipc.UAPIListen(interfaceName, fileUAPI)
|
return ipc.UAPIListen(interfaceName, fileUAPI)
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build windows
|
//go:build windows
|
||||||
|
|
||||||
package olm
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@@ -11,15 +11,15 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||||
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
return nil, errors.New("CreateTUNFromFile not supported on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
func UapiOpen(interfaceName string) (*os.File, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
func UapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) {
|
||||||
// On Windows, UAPIListen only takes one parameter
|
// On Windows, UAPIListen only takes one parameter
|
||||||
return ipc.UAPIListen(interfaceName)
|
return ipc.UAPIListen(interfaceName)
|
||||||
}
|
}
|
||||||
523
diff
523
diff
@@ -1,523 +0,0 @@
|
|||||||
diff --git a/api/api.go b/api/api.go
|
|
||||||
index dd07751..0d2e4ef 100644
|
|
||||||
--- a/api/api.go
|
|
||||||
+++ b/api/api.go
|
|
||||||
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
}
|
|
||||||
|
|
||||||
+// SwitchOrgRequest defines the structure for switching organizations
|
|
||||||
+type SwitchOrgRequest struct {
|
|
||||||
+ OrgID string `json:"orgId"`
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
// PeerStatus represents the status of a peer connection
|
|
||||||
type PeerStatus struct {
|
|
||||||
SiteID int `json:"siteId"`
|
|
||||||
@@ -35,6 +40,7 @@ type StatusResponse struct {
|
|
||||||
Registered bool `json:"registered"`
|
|
||||||
TunnelIP string `json:"tunnelIP,omitempty"`
|
|
||||||
Version string `json:"version,omitempty"`
|
|
||||||
+ OrgID string `json:"orgId,omitempty"`
|
|
||||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -46,6 +52,7 @@ type API struct {
|
|
||||||
server *http.Server
|
|
||||||
connectionChan chan ConnectionRequest
|
|
||||||
shutdownChan chan struct{}
|
|
||||||
+ switchOrgChan chan SwitchOrgRequest
|
|
||||||
statusMu sync.RWMutex
|
|
||||||
peerStatuses map[int]*PeerStatus
|
|
||||||
connectedAt time.Time
|
|
||||||
@@ -53,6 +60,7 @@ type API struct {
|
|
||||||
isRegistered bool
|
|
||||||
tunnelIP string
|
|
||||||
version string
|
|
||||||
+ orgID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
|
||||||
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
|
|
||||||
addr: addr,
|
|
||||||
connectionChan: make(chan ConnectionRequest, 1),
|
|
||||||
shutdownChan: make(chan struct{}, 1),
|
|
||||||
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
|
|
||||||
socketPath: socketPath,
|
|
||||||
connectionChan: make(chan ConnectionRequest, 1),
|
|
||||||
shutdownChan: make(chan struct{}, 1),
|
|
||||||
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
|
||||||
peerStatuses: make(map[int]*PeerStatus),
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -85,6 +95,7 @@ func (s *API) Start() error {
|
|
||||||
mux.HandleFunc("/connect", s.handleConnect)
|
|
||||||
mux.HandleFunc("/status", s.handleStatus)
|
|
||||||
mux.HandleFunc("/exit", s.handleExit)
|
|
||||||
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
|
||||||
|
|
||||||
s.server = &http.Server{
|
|
||||||
Handler: mux,
|
|
||||||
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
|
||||||
return s.shutdownChan
|
|
||||||
}
|
|
||||||
|
|
||||||
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
|
||||||
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
|
||||||
+ return s.switchOrgChan
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
|
||||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
|
||||||
s.statusMu.Lock()
|
|
||||||
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
|
|
||||||
s.version = version
|
|
||||||
}
|
|
||||||
|
|
||||||
+// SetOrgID sets the org ID
|
|
||||||
+func (s *API) SetOrgID(orgID string) {
|
|
||||||
+ s.statusMu.Lock()
|
|
||||||
+ defer s.statusMu.Unlock()
|
|
||||||
+ s.orgID = orgID
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
|
||||||
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
|
||||||
s.statusMu.Lock()
|
|
||||||
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
|
||||||
Registered: s.isRegistered,
|
|
||||||
TunnelIP: s.tunnelIP,
|
|
||||||
Version: s.version,
|
|
||||||
+ OrgID: s.orgID,
|
|
||||||
PeerStatuses: s.peerStatuses,
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
|
||||||
"status": "shutdown initiated",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
+
|
|
||||||
+// handleSwitchOrg handles the /switch-org endpoint
|
|
||||||
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
|
||||||
+ if r.Method != http.MethodPost {
|
|
||||||
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
+ return
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ var req SwitchOrgRequest
|
|
||||||
+ decoder := json.NewDecoder(r.Body)
|
|
||||||
+ if err := decoder.Decode(&req); err != nil {
|
|
||||||
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
|
||||||
+ return
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // Validate required fields
|
|
||||||
+ if req.OrgID == "" {
|
|
||||||
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
|
||||||
+ return
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
|
||||||
+
|
|
||||||
+ // Send the request to the main goroutine
|
|
||||||
+ select {
|
|
||||||
+ case s.switchOrgChan <- req:
|
|
||||||
+ // Signal sent successfully
|
|
||||||
+ default:
|
|
||||||
+ // Channel already has a signal, don't block
|
|
||||||
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
|
|
||||||
+ return
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // Return a success response
|
|
||||||
+ w.Header().Set("Content-Type", "application/json")
|
|
||||||
+ w.WriteHeader(http.StatusAccepted)
|
|
||||||
+ json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
+ "status": "org switch initiated",
|
|
||||||
+ "orgId": req.OrgID,
|
|
||||||
+ })
|
|
||||||
+}
|
|
||||||
diff --git a/olm/olm.go b/olm/olm.go
|
|
||||||
index 78080c4..5e292d6 100644
|
|
||||||
--- a/olm/olm.go
|
|
||||||
+++ b/olm/olm.go
|
|
||||||
@@ -58,6 +58,58 @@ type Config struct {
|
|
||||||
OrgID string
|
|
||||||
}
|
|
||||||
|
|
||||||
+// tunnelState holds all the active tunnel resources that need cleanup
|
|
||||||
+type tunnelState struct {
|
|
||||||
+ dev *device.Device
|
|
||||||
+ tdev tun.Device
|
|
||||||
+ uapiListener net.Listener
|
|
||||||
+ peerMonitor *peermonitor.PeerMonitor
|
|
||||||
+ stopRegister func()
|
|
||||||
+ connected bool
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
+// teardownTunnel cleans up all tunnel resources
|
|
||||||
+func teardownTunnel(state *tunnelState) {
|
|
||||||
+ if state == nil {
|
|
||||||
+ return
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ logger.Info("Tearing down tunnel...")
|
|
||||||
+
|
|
||||||
+ // Stop registration messages
|
|
||||||
+ if state.stopRegister != nil {
|
|
||||||
+ state.stopRegister()
|
|
||||||
+ state.stopRegister = nil
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // Stop peer monitor
|
|
||||||
+ if state.peerMonitor != nil {
|
|
||||||
+ state.peerMonitor.Stop()
|
|
||||||
+ state.peerMonitor = nil
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // Close UAPI listener
|
|
||||||
+ if state.uapiListener != nil {
|
|
||||||
+ state.uapiListener.Close()
|
|
||||||
+ state.uapiListener = nil
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // Close WireGuard device
|
|
||||||
+ if state.dev != nil {
|
|
||||||
+ state.dev.Close()
|
|
||||||
+ state.dev = nil
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ // Close TUN device
|
|
||||||
+ if state.tdev != nil {
|
|
||||||
+ state.tdev.Close()
|
|
||||||
+ state.tdev = nil
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ state.connected = false
|
|
||||||
+ logger.Info("Tunnel teardown complete")
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
func Run(ctx context.Context, config Config) {
|
|
||||||
// Create a cancellable context for internal shutdown control
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
pingTimeout = config.PingTimeoutDuration
|
|
||||||
doHolepunch = config.Holepunch
|
|
||||||
privateKey wgtypes.Key
|
|
||||||
- connected bool
|
|
||||||
- dev *device.Device
|
|
||||||
wgData WgData
|
|
||||||
holePunchData HolePunchData
|
|
||||||
- uapiListener net.Listener
|
|
||||||
- tdev tun.Device
|
|
||||||
+ orgID = config.OrgID
|
|
||||||
)
|
|
||||||
|
|
||||||
+ // Tunnel state that can be torn down and recreated
|
|
||||||
+ tunnel := &tunnelState{}
|
|
||||||
+
|
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
stopPing = make(chan struct{})
|
|
||||||
|
|
||||||
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
}
|
|
||||||
|
|
||||||
apiServer.SetVersion(config.Version)
|
|
||||||
+ apiServer.SetOrgID(orgID)
|
|
||||||
if err := apiServer.Start(); err != nil {
|
|
||||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
|
||||||
}
|
|
||||||
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
|
||||||
|
|
||||||
- if connected {
|
|
||||||
+ if tunnel.connected {
|
|
||||||
logger.Info("Already connected. Ignoring new connection request.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
- if stopRegister != nil {
|
|
||||||
- stopRegister()
|
|
||||||
- stopRegister = nil
|
|
||||||
+ if tunnel.stopRegister != nil {
|
|
||||||
+ tunnel.stopRegister()
|
|
||||||
+ tunnel.stopRegister = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
close(stopHolepunch)
|
|
||||||
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// if there is an existing tunnel then close it
|
|
||||||
- if dev != nil {
|
|
||||||
+ if tunnel.dev != nil {
|
|
||||||
logger.Info("Got new message. Closing existing tunnel!")
|
|
||||||
- dev.Close()
|
|
||||||
+ tunnel.dev.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
- tdev, err = func() (tun.Device, error) {
|
|
||||||
+ tunnel.tdev, err = func() (tun.Device, error) {
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
interfaceName, err := findUnusedUTUN()
|
|
||||||
if err != nil {
|
|
||||||
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
|
||||||
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
|
|
||||||
interfaceName = realInterfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
|
||||||
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
|
||||||
|
|
||||||
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
|
||||||
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
- conn, err := uapiListener.Accept()
|
|
||||||
+ conn, err := tunnel.uapiListener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
- go dev.IpcHandle(conn)
|
|
||||||
+ go tunnel.dev.IpcHandle(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
logger.Info("UAPI listener started")
|
|
||||||
|
|
||||||
- if err = dev.Up(); err != nil {
|
|
||||||
+ if err = tunnel.dev.Up(); err != nil {
|
|
||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
|
||||||
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
apiServer.SetTunnelIP(wgData.TunnelIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
- peerMonitor = peermonitor.NewPeerMonitor(
|
|
||||||
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
|
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
|
||||||
if apiServer != nil {
|
|
||||||
// Find the site config to get endpoint information
|
|
||||||
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
},
|
|
||||||
fixKey(privateKey.String()),
|
|
||||||
olm,
|
|
||||||
- dev,
|
|
||||||
+ tunnel.dev,
|
|
||||||
doHolepunch,
|
|
||||||
)
|
|
||||||
|
|
||||||
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
// Format the endpoint before configuring the peer.
|
|
||||||
site.Endpoint = formatEndpoint(site.Endpoint)
|
|
||||||
|
|
||||||
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
|
||||||
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
|
|
||||||
logger.Error("Failed to configure peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
logger.Info("Configured peer %s", site.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
- peerMonitor.Start()
|
|
||||||
+ tunnel.peerMonitor.Start()
|
|
||||||
|
|
||||||
if apiServer != nil {
|
|
||||||
apiServer.SetRegistered(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
- connected = true
|
|
||||||
+ tunnel.connected = true
|
|
||||||
|
|
||||||
logger.Info("WireGuard device created.")
|
|
||||||
})
|
|
||||||
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the peer in WireGuard
|
|
||||||
- if dev != nil {
|
|
||||||
+ if tunnel.dev != nil {
|
|
||||||
// Find the existing peer to get old data
|
|
||||||
var oldRemoteSubnets string
|
|
||||||
var oldPublicKey string
|
|
||||||
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
// If the public key has changed, remove the old peer first
|
|
||||||
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
|
||||||
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
|
||||||
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
|
||||||
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
|
|
||||||
logger.Error("Failed to remove old peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
// Format the endpoint before updating the peer.
|
|
||||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
|
||||||
|
|
||||||
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
|
||||||
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
|
||||||
logger.Error("Failed to update peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the peer to WireGuard
|
|
||||||
- if dev != nil {
|
|
||||||
+ if tunnel.dev != nil {
|
|
||||||
// Format the endpoint before adding the new peer.
|
|
||||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
|
||||||
|
|
||||||
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
|
||||||
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
|
||||||
logger.Error("Failed to add peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the peer from WireGuard
|
|
||||||
- if dev != nil {
|
|
||||||
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
|
||||||
+ if tunnel.dev != nil {
|
|
||||||
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
|
||||||
logger.Error("Failed to remove peer: %v", err)
|
|
||||||
// Send error response if needed
|
|
||||||
return
|
|
||||||
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
|
||||||
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
|
||||||
})
|
|
||||||
|
|
||||||
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
|
||||||
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
apiServer.SetConnectionStatus(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
- if connected {
|
|
||||||
+ if tunnel.connected {
|
|
||||||
logger.Debug("Already connected, skipping registration")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
|
|
||||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
|
||||||
|
|
||||||
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
||||||
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": publicKey.String(),
|
|
||||||
"relay": !doHolepunch,
|
|
||||||
"olmVersion": config.Version,
|
|
||||||
- "orgId": config.OrgID,
|
|
||||||
+ "orgId": orgID,
|
|
||||||
}, 1*time.Second)
|
|
||||||
|
|
||||||
go keepSendingPing(olm)
|
|
||||||
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
}
|
|
||||||
defer olm.Close()
|
|
||||||
|
|
||||||
+ // Listen for org switch requests from the API (after olm is created)
|
|
||||||
+ if apiServer != nil {
|
|
||||||
+ go func() {
|
|
||||||
+ for req := range apiServer.GetSwitchOrgChannel() {
|
|
||||||
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
|
|
||||||
+
|
|
||||||
+ // Update the orgId
|
|
||||||
+ orgID = req.OrgID
|
|
||||||
+
|
|
||||||
+ // Teardown existing tunnel
|
|
||||||
+ teardownTunnel(tunnel)
|
|
||||||
+
|
|
||||||
+ // Reset tunnel state
|
|
||||||
+ tunnel = &tunnelState{}
|
|
||||||
+
|
|
||||||
+ // Stop holepunch
|
|
||||||
+ select {
|
|
||||||
+ case <-stopHolepunch:
|
|
||||||
+ // Channel already closed
|
|
||||||
+ default:
|
|
||||||
+ close(stopHolepunch)
|
|
||||||
+ }
|
|
||||||
+ stopHolepunch = make(chan struct{})
|
|
||||||
+
|
|
||||||
+ // Clear API server state
|
|
||||||
+ apiServer.SetRegistered(false)
|
|
||||||
+ apiServer.SetTunnelIP("")
|
|
||||||
+ apiServer.SetOrgID(orgID)
|
|
||||||
+
|
|
||||||
+ // Send new registration message with updated orgId
|
|
||||||
+ publicKey := privateKey.PublicKey()
|
|
||||||
+ logger.Info("Sending registration message with new orgId: %s", orgID)
|
|
||||||
+
|
|
||||||
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
||||||
+ "publicKey": publicKey.String(),
|
|
||||||
+ "relay": !doHolepunch,
|
|
||||||
+ "olmVersion": config.Version,
|
|
||||||
+ "orgId": orgID,
|
|
||||||
+ }, 1*time.Second)
|
|
||||||
+ }
|
|
||||||
+ }()
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Info("Context cancelled")
|
|
||||||
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
close(stopHolepunch)
|
|
||||||
}
|
|
||||||
|
|
||||||
- if stopRegister != nil {
|
|
||||||
- stopRegister()
|
|
||||||
- stopRegister = nil
|
|
||||||
+ if tunnel.stopRegister != nil {
|
|
||||||
+ tunnel.stopRegister()
|
|
||||||
+ tunnel.stopRegister = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
|
|
||||||
close(stopPing)
|
|
||||||
}
|
|
||||||
|
|
||||||
- if peerMonitor != nil {
|
|
||||||
- peerMonitor.Stop()
|
|
||||||
- }
|
|
||||||
-
|
|
||||||
- if uapiListener != nil {
|
|
||||||
- uapiListener.Close()
|
|
||||||
- }
|
|
||||||
- if dev != nil {
|
|
||||||
- dev.Close()
|
|
||||||
- }
|
|
||||||
+ // Use teardownTunnel to clean up all tunnel resources
|
|
||||||
+ teardownTunnel(tunnel)
|
|
||||||
|
|
||||||
if apiServer != nil {
|
|
||||||
apiServer.Stop()
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package olm
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -10,16 +10,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/network"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||||
func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
|
func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error {
|
||||||
logger.Info("The tunnel IP is: %s", wgData.TunnelIP)
|
logger.Info("The tunnel IP is: %s", tunnelIp)
|
||||||
|
|
||||||
// Parse the IP address and network
|
// Parse the IP address and network
|
||||||
ip, ipNet, err := net.ParseCIDR(wgData.TunnelIP)
|
ip, ipNet, err := net.ParseCIDR(tunnelIp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid IP address: %v", err)
|
return fmt.Errorf("invalid IP address: %v", err)
|
||||||
}
|
}
|
||||||
@@ -31,9 +30,8 @@ func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
|
|||||||
logger.Debug("The destination address is: %s", destinationAddress)
|
logger.Debug("The destination address is: %s", destinationAddress)
|
||||||
|
|
||||||
// network.SetTunnelRemoteAddress() // what does this do?
|
// network.SetTunnelRemoteAddress() // what does this do?
|
||||||
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||||
network.SetMTU(mtu)
|
SetMTU(mtu)
|
||||||
apiServer.SetTunnelIP(destinationAddress)
|
|
||||||
|
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
return nil
|
return nil
|
||||||
@@ -89,7 +87,7 @@ func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Du
|
|||||||
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func findUnusedUTUN() (string, error) {
|
func FindUnusedUTUN() (string, error) {
|
||||||
ifaces, err := net.Interfaces()
|
ifaces, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build !windows
|
//go:build !windows
|
||||||
|
|
||||||
package olm
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build windows
|
//go:build windows
|
||||||
|
|
||||||
package olm
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package olm
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/network"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -126,8 +125,8 @@ func LinuxRemoveRoute(destination string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addRouteForServerIP adds an OS-specific route for the server IP
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||||
func addRouteForServerIP(serverIP, interfaceName string) error {
|
func AddRouteForServerIP(serverIP, interfaceName string) error {
|
||||||
if err := addRouteForNetworkConfig(serverIP); err != nil {
|
if err := AddRouteForNetworkConfig(serverIP); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
@@ -145,8 +144,8 @@ func addRouteForServerIP(serverIP, interfaceName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// removeRouteForServerIP removes an OS-specific route for the server IP
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||||
func removeRouteForServerIP(serverIP string, interfaceName string) error {
|
func RemoveRouteForServerIP(serverIP string, interfaceName string) error {
|
||||||
if err := removeRouteForNetworkConfig(serverIP); err != nil {
|
if err := RemoveRouteForNetworkConfig(serverIP); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
@@ -163,7 +162,7 @@ func removeRouteForServerIP(serverIP string, interfaceName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRouteForNetworkConfig(destination string) error {
|
func AddRouteForNetworkConfig(destination string) error {
|
||||||
// Parse the subnet to extract IP and mask
|
// Parse the subnet to extract IP and mask
|
||||||
_, ipNet, err := net.ParseCIDR(destination)
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -174,12 +173,12 @@ func addRouteForNetworkConfig(destination string) error {
|
|||||||
mask := net.IP(ipNet.Mask).String()
|
mask := net.IP(ipNet.Mask).String()
|
||||||
destinationAddress := ipNet.IP.String()
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
network.AddIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeRouteForNetworkConfig(destination string) error {
|
func RemoveRouteForNetworkConfig(destination string) error {
|
||||||
// Parse the subnet to extract IP and mask
|
// Parse the subnet to extract IP and mask
|
||||||
_, ipNet, err := net.ParseCIDR(destination)
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -190,13 +189,13 @@ func removeRouteForNetworkConfig(destination string) error {
|
|||||||
mask := net.IP(ipNet.Mask).String()
|
mask := net.IP(ipNet.Mask).String()
|
||||||
destinationAddress := ipNet.IP.String()
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
network.RemoveIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addRoutes adds routes for each subnet in RemoteSubnets
|
// addRoutes adds routes for each subnet in RemoteSubnets
|
||||||
func addRoutes(remoteSubnets []string, interfaceName string) error {
|
func AddRoutes(remoteSubnets []string, interfaceName string) error {
|
||||||
if len(remoteSubnets) == 0 {
|
if len(remoteSubnets) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -208,7 +207,7 @@ func addRoutes(remoteSubnets []string, interfaceName string) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := addRouteForNetworkConfig(subnet); err != nil {
|
if err := AddRouteForNetworkConfig(subnet); err != nil {
|
||||||
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
|
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -241,7 +240,7 @@ func addRoutes(remoteSubnets []string, interfaceName string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
|
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
|
||||||
func removeRoutesForRemoteSubnets(remoteSubnets []string) error {
|
func RemoveRoutesForRemoteSubnets(remoteSubnets []string) error {
|
||||||
if len(remoteSubnets) == 0 {
|
if len(remoteSubnets) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -253,7 +252,7 @@ func removeRoutesForRemoteSubnets(remoteSubnets []string) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeRouteForNetworkConfig(subnet); err != nil {
|
if err := RemoveRouteForNetworkConfig(subnet); err != nil {
|
||||||
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
|
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build !windows
|
//go:build !windows
|
||||||
|
|
||||||
package olm
|
package network
|
||||||
|
|
||||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
return nil
|
return nil
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//go:build windows
|
//go:build windows
|
||||||
|
|
||||||
package olm
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -177,6 +177,12 @@ func GetJSON() (string, error) {
|
|||||||
return string(data), nil
|
return string(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetSettings() NetworkSettings {
|
||||||
|
networkSettingsMutex.RLock()
|
||||||
|
defer networkSettingsMutex.RUnlock()
|
||||||
|
return networkSettings
|
||||||
|
}
|
||||||
|
|
||||||
func GetIncrementor() int {
|
func GetIncrementor() int {
|
||||||
networkSettingsMutex.Lock()
|
networkSettingsMutex.Lock()
|
||||||
defer networkSettingsMutex.Unlock()
|
defer networkSettingsMutex.Unlock()
|
||||||
42
olm/olm.go
42
olm/olm.go
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/olm/api"
|
"github.com/fosrl/olm/api"
|
||||||
middleDevice "github.com/fosrl/olm/device"
|
olmDevice "github.com/fosrl/olm/device"
|
||||||
"github.com/fosrl/olm/dns"
|
"github.com/fosrl/olm/dns"
|
||||||
dnsOverride "github.com/fosrl/olm/dns/override"
|
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||||
"github.com/fosrl/olm/network"
|
"github.com/fosrl/olm/network"
|
||||||
@@ -79,7 +79,7 @@ var (
|
|||||||
holePunchData HolePunchData
|
holePunchData HolePunchData
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
tdev tun.Device
|
tdev tun.Device
|
||||||
middleDev *middleDevice.MiddleDevice
|
middleDev *olmDevice.MiddleDevice
|
||||||
dnsProxy *dns.DNSProxy
|
dnsProxy *dns.DNSProxy
|
||||||
apiServer *api.API
|
apiServer *api.API
|
||||||
olmClient *websocket.Client
|
olmClient *websocket.Client
|
||||||
@@ -201,7 +201,6 @@ func Init(ctx context.Context, config GlobalConfig) {
|
|||||||
|
|
||||||
// Clear peer statuses in API
|
// Clear peer statuses in API
|
||||||
apiServer.SetRegistered(false)
|
apiServer.SetRegistered(false)
|
||||||
apiServer.SetTunnelIP("")
|
|
||||||
|
|
||||||
// Trigger re-registration with new orgId
|
// Trigger re-registration with new orgId
|
||||||
logger.Info("Re-registering with new orgId: %s", req.OrgID)
|
logger.Info("Re-registering with new orgId: %s", req.OrgID)
|
||||||
@@ -418,11 +417,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
tdev, err = func() (tun.Device, error) {
|
tdev, err = func() (tun.Device, error) {
|
||||||
if config.FileDescriptorTun != 0 {
|
if config.FileDescriptorTun != 0 {
|
||||||
return createTUNFromFD(config.FileDescriptorTun, config.MTU)
|
return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU)
|
||||||
}
|
}
|
||||||
var ifName = interfaceName
|
var ifName = interfaceName
|
||||||
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
||||||
ifName, err = findUnusedUTUN()
|
ifName, err = network.FindUnusedUTUN()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -458,7 +457,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
// Wrap TUN device with packet filter for DNS proxy
|
// Wrap TUN device with packet filter for DNS proxy
|
||||||
middleDev = middleDevice.NewMiddleDevice(tdev)
|
middleDev = olmDevice.NewMiddleDevice(tdev)
|
||||||
|
|
||||||
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||||
// Use filtered device instead of raw TUN device
|
// Use filtered device instead of raw TUN device
|
||||||
@@ -495,11 +494,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to create DNS proxy: %v", err)
|
logger.Error("Failed to create DNS proxy: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil {
|
if err = network.ConfigureInterface(interfaceName, wgData.TunnelIP, config.MTU); err != nil {
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet
|
if network.AddRoutes([]string{wgData.UtilitySubnet}, interfaceName); err != nil { // also route the utility subnet
|
||||||
logger.Error("Failed to add route for utility subnet: %v", err)
|
logger.Error("Failed to add route for utility subnet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -549,11 +548,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to configure peer: %v", err)
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required
|
if err := network.AddRouteForServerIP(site.ServerIP, interfaceName); err != nil { // this is something for darwin only thats required
|
||||||
logger.Error("Failed to add route for peer: %v", err)
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRoutes(site.RemoteSubnets, interfaceName); err != nil {
|
if err := network.AddRoutes(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -676,13 +675,13 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
// Handle remote subnet route changes
|
// Handle remote subnet route changes
|
||||||
if !stringSlicesEqual(oldRemoteSubnets, siteConfig.RemoteSubnets) {
|
if !stringSlicesEqual(oldRemoteSubnets, siteConfig.RemoteSubnets) {
|
||||||
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
if err := network.RemoveRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
||||||
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
||||||
// Continue anyway to add new routes
|
// Continue anyway to add new routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new remote subnet routes
|
// Add new remote subnet routes
|
||||||
if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
if err := network.AddRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add new remote subnet routes: %v", err)
|
logger.Error("Failed to add new remote subnet routes: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -721,11 +720,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
|
if err := network.AddRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add route for new peer: %v", err)
|
logger.Error("Failed to add route for new peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
if err := network.AddRoutes(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -782,14 +781,14 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove route for the peer
|
// Remove route for the peer
|
||||||
err = removeRouteForServerIP(peerToRemove.ServerIP, interfaceName)
|
err = network.RemoveRouteForServerIP(peerToRemove.ServerIP, interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to remove route for peer: %v", err)
|
logger.Error("Failed to remove route for peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove routes for remote subnets
|
// Remove routes for remote subnets
|
||||||
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
if err := network.RemoveRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
||||||
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -851,7 +850,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add routes for the new subnets
|
// Add routes for the new subnets
|
||||||
if err := addRoutes(newSubnets, interfaceName); err != nil {
|
if err := network.AddRoutes(newSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -912,7 +911,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove routes for the removed subnets
|
// Remove routes for the removed subnets
|
||||||
if err := removeRoutesForRemoteSubnets(removedSubnets); err != nil {
|
if err := network.RemoveRoutesForRemoteSubnets(removedSubnets); err != nil {
|
||||||
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -955,7 +954,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
// First, remove routes for old subnets
|
// First, remove routes for old subnets
|
||||||
if len(updateSubnetsData.OldRemoteSubnets) > 0 {
|
if len(updateSubnetsData.OldRemoteSubnets) > 0 {
|
||||||
if err := removeRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets); err != nil {
|
if err := network.RemoveRoutesForRemoteSubnets(updateSubnetsData.OldRemoteSubnets); err != nil {
|
||||||
logger.Error("Failed to remove routes for old remote subnets: %v", err)
|
logger.Error("Failed to remove routes for old remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -964,10 +963,10 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
// Then, add routes for new subnets
|
// Then, add routes for new subnets
|
||||||
if len(updateSubnetsData.NewRemoteSubnets) > 0 {
|
if len(updateSubnetsData.NewRemoteSubnets) > 0 {
|
||||||
if err := addRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil {
|
if err := network.AddRoutes(updateSubnetsData.NewRemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
logger.Error("Failed to add routes for new remote subnets: %v", err)
|
||||||
// Attempt to rollback by re-adding old routes
|
// Attempt to rollback by re-adding old routes
|
||||||
if rollbackErr := addRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil {
|
if rollbackErr := network.AddRoutes(updateSubnetsData.OldRemoteSubnets, interfaceName); rollbackErr != nil {
|
||||||
logger.Error("Failed to rollback old routes: %v", rollbackErr)
|
logger.Error("Failed to rollback old routes: %v", rollbackErr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1186,7 +1185,6 @@ func StopTunnel() {
|
|||||||
// Update API server status
|
// Update API server status
|
||||||
apiServer.SetConnectionStatus(false)
|
apiServer.SetConnectionStatus(false)
|
||||||
apiServer.SetRegistered(false)
|
apiServer.SetRegistered(false)
|
||||||
apiServer.SetTunnelIP("")
|
|
||||||
|
|
||||||
network.ClearNetworkSettings()
|
network.ClearNetworkSettings()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user