diff --git a/.dockerignore b/.dockerignore index e5cc1f8..df8d8ae 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,6 @@ .gitignore .dockerignore -newt +olm *.json README.md Makefile diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 2d339b8..20f5df7 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -36,7 +36,7 @@ jobs: run: | TAG=${{ env.TAG }} if [ -f main.go ]; then - sed -i 's/Newt version replaceme/Newt version '"$TAG"'/' main.go + sed -i 's/Olm version replaceme/Olm version '"$TAG"'/' main.go echo "Updated main.go with version $TAG" else echo "main.go not found" @@ -52,7 +52,7 @@ jobs: make go-build-release - name: Upload artifacts from /bin - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: binaries path: bin/ diff --git a/.gitignore b/.gitignore index 8b1c477..6a52691 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ -newt +olm .DS_Store bin/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index d573c7b..f3dddb3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ RUN go mod download COPY . . # Build the application -RUN CGO_ENABLED=0 GOOS=linux go build -o /newt +RUN CGO_ENABLED=0 GOOS=linux go build -o /olm # Start a new stage from scratch FROM ubuntu:22.04 AS runner @@ -21,7 +21,7 @@ FROM ubuntu:22.04 AS runner RUN apt-get update && apt-get install ca-certificates -y && rm -rf /var/lib/apt/lists/* # Copy the pre-built binary file from the previous stage and the entrypoint script -COPY --from=builder /newt /usr/local/bin/ +COPY --from=builder /olm /usr/local/bin/ COPY entrypoint.sh / RUN chmod +x /entrypoint.sh @@ -30,4 +30,4 @@ RUN chmod +x /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] # Command to run the executable -CMD ["newt"] \ No newline at end of file +CMD ["olm"] \ No newline at end of file diff --git a/Makefile b/Makefile index 09e1cfa..9303e87 100644 --- a/Makefile +++ b/Makefile @@ -6,30 +6,27 @@ docker-build-release: echo "Error: tag is required. Usage: make build-all tag="; \ exit 1; \ fi - docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/newt:latest -f Dockerfile --push . - docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/newt:$(tag) -f Dockerfile --push . + docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/olm:latest -f Dockerfile --push . + docker buildx build --platform linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push . build: - docker build -t fosrl/newt:latest . + docker build -t fosrl/olm:latest . push: - docker push fosrl/newt:latest + docker push fosrl/olm:latest test: - docker run fosrl/newt:latest + docker run fosrl/olm:latest local: - CGO_ENABLED=0 go build -o newt + CGO_ENABLED=0 go build -o olm go-build-release: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/newt_linux_arm64 - CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -o bin/newt_linux_arm32 - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/newt_linux_amd64 - CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/newt_darwin_arm64 - CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/newt_darwin_amd64 - CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/newt_windows_amd64.exe - CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 go build -o bin/newt_freebsd_amd64 - CGO_ENABLED=0 GOOS=freebsd GOARCH=arm64 go build -o bin/newt_freebsd_arm64 - + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/olm_linux_arm64 + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/olm_linux_amd64 + CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o bin/olm_darwin_arm64 + CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/olm_darwin_amd64 + CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/olm_windows_amd64.exe + clean: - rm newt + rm olm diff --git a/README.md b/README.md index 471b603..6809b69 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,40 @@ -# Newt +# Olm -Newt is a fully user space [WireGuard](https://www.wireguard.com/) tunnel client and TCP/UDP proxy, designed to securely expose private resources controlled by Pangolin. By using Newt, you don't need to manage complex WireGuard tunnels and NATing. +Olm is a [WireGuard](https://www.wireguard.com/) tunnel manager designed to securely connect to private resources. By using Olm, you don't need to manage complex WireGuard tunnels. ### Installation and Documentation -Newt is used with Pangolin and Gerbil as part of the larger system. See documentation below: +Olm is used with Pangolin and Newt as part of the larger system. See documentation below: - [Installation Instructions](https://docs.fossorial.io) - [Full Documentation](https://docs.fossorial.io) -## Preview - -Preview - -_Sample output of a Newt container connected to Pangolin and hosting various resource target proxies._ - ## Key Functions ### Registers with Pangolin -Using the Newt ID and a secret, the client will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket. +Using the Olm ID and a secret, the olm will make HTTP requests to Pangolin to receive a session token. Using that token, it will connect to a websocket and maintain that connection. Control messages will be sent over the websocket. ### Receives WireGuard Control Messages -When Newt receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel using [netstack](https://github.com/WireGuard/wireguard-go/blob/master/tun/netstack/examples/http_server.go) fully in user space. It will ping over the tunnel to ensure the peer on the Gerbil side is brought up. +When Olm receives WireGuard control messages, it will use the information encoded (endpoint, public key) to bring up a WireGuard tunnel using [netstack](https://github.com/WireGuard/wireguard-go/blob/master/tun/netstack/examples/http_server.go) fully in user space. It will ping over the tunnel to ensure the peer on the Gerbil side is brought up. ### Receives Proxy Control Messages -When Newt receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets. +When Olm receives WireGuard control messages, it will use the information encoded to create a local low level TCP and UDP proxies attached to the virtual tunnel in order to relay traffic to programmed targets. ## CLI Args - `endpoint`: The endpoint where both Gerbil and Pangolin reside in order to connect to the websocket. -- `id`: Newt ID generated by Pangolin to identify the client. -- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands. +- `id`: Olm ID generated by Pangolin to identify the olm. +- `secret`: A unique secret (not shared and kept private) used to authenticate the olm ID with the websocket in order to receive commands. - `dns`: DNS server to use to resolve the endpoint - `log-level` (optional): The log level to use. Default: INFO Example: ```bash -./newt \ +./olm \ --id 31frd0uzbjvp721 \ --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \ --endpoint https://example.com @@ -50,23 +44,23 @@ You can also run it with Docker compose. For example, a service in your `docker- ```yaml services: - newt: - image: fosrl/newt - container_name: newt + olm: + image: fosrl/olm + container_name: olm restart: unless-stopped environment: - PANGOLIN_ENDPOINT=https://example.com - - NEWT_ID=2ix2t8xk22ubpfy - - NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 + - OLM_ID=2ix2t8xk22ubpfy + - OLM_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 ``` You can also pass the CLI args to the container: ```yaml services: - newt: - image: fosrl/newt - container_name: newt + olm: + image: fosrl/olm + container_name: olm restart: unless-stopped command: - --id 31frd0uzbjvp721 @@ -78,11 +72,11 @@ Finally a basic systemd service: ``` [Unit] -Description=Newt VPN Client +Description=Olm VPN Olm After=network.target [Service] -ExecStart=/usr/local/bin/newt --id 31frd0uzbjvp721 --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 --endpoint https://example.com +ExecStart=/usr/local/bin/olm --id 31frd0uzbjvp721 --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 --endpoint https://example.com Restart=always User=root @@ -90,7 +84,70 @@ User=root WantedBy=multi-user.target ``` -Make sure to `mv ./newt /usr/local/bin/newt`! +Make sure to `mv ./olm /usr/local/bin/olm`! + +## Windows Service + +On Windows, Olm can be installed and run as a Windows service. This allows it to start automatically at boot and run in the background. + +### Service Management Commands + +```cmd +# Install the service +olm.exe install + +# Start the service +olm.exe start + +# Stop the service +olm.exe stop + +# Check service status +olm.exe status + +# Remove the service +olm.exe remove + +# Run in debug mode (console output) +olm.exe debug + +# Show help +olm.exe help +``` + +**Helper Scripts**: For easier service management, you can use the provided helper scripts: +- `olm-service.bat` - Batch script (requires Administrator privileges) +- `olm-service.ps1` - PowerShell script with better error handling + +Example using the batch script: +```cmd +# Run as Administrator +olm-service.bat install +olm-service.bat start +olm-service.bat status +``` + +### Service Configuration + +When running as a service, Olm will read configuration from environment variables or you can modify the service to include command-line arguments: + +1. Install the service: `olm.exe install` +2. Configure the service with your credentials using Windows Service Manager or by setting system environment variables: + - `PANGOLIN_ENDPOINT=https://example.com` + - `OLM_ID=your_olm_id` + - `OLM_SECRET=your_secret` +3. Start the service: `olm.exe start` + +### Service Logs + +When running as a service, logs are written to: +- Windows Event Log (Application log, source: "OlmWireguardService") +- Log files in: `%PROGRAMDATA%\Olm\logs\olm.log` + +You can view the Windows Event Log using Event Viewer or PowerShell: +```powershell +Get-EventLog -LogName Application -Source "OlmWireguardService" -Newest 10 +``` ## Build @@ -112,7 +169,7 @@ make local ## Licensing -Newt is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us. +Olm is dual licensed under the AGPLv3 and the Fossorial Commercial license. For inquiries about commercial licensing, please contact us. ## Contributions diff --git a/common.go b/common.go new file mode 100644 index 0000000..db8c155 --- /dev/null +++ b/common.go @@ -0,0 +1,1030 @@ +package main + +import ( + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net" + "os/exec" + "regexp" + "runtime" + "strconv" + "strings" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/websocket" + "github.com/fosrl/olm/peermonitor" + "github.com/vishvananda/netlink" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/exp/rand" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type WgData struct { + Sites []SiteConfig `json:"sites"` + TunnelIP string `json:"tunnelIP"` +} + +type SiteConfig struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access +} + +type TargetsByType struct { + UDP []string `json:"udp"` + TCP []string `json:"tcp"` +} + +type TargetData struct { + Targets []string `json:"targets"` +} + +type HolePunchMessage struct { + NewtID string `json:"newtId"` +} + +type HolePunchData struct { + ServerPubKey string `json:"serverPubKey"` + Endpoint string `json:"endpoint"` +} + +type EncryptedHolePunchMessage struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` +} + +var ( + peerMonitor *peermonitor.PeerMonitor + stopHolepunch chan struct{} + stopRegister func() + stopPing chan struct{} + olmToken string + gerbilServerPubKey string + holePunchRunning bool +) + +const ( + ENV_WG_TUN_FD = "WG_TUN_FD" + ENV_WG_UAPI_FD = "WG_UAPI_FD" + ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" +) + +type fixedPortBind struct { + port uint16 + conn.Bind +} + +// PeerAction represents a request to add, update, or remove a peer +type PeerAction struct { + Action string `json:"action"` // "add", "update", or "remove" + SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information +} + +// UpdatePeerData represents the data needed to update a peer +type UpdatePeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access +} + +// AddPeerData represents the data needed to add a peer +type AddPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access +} + +// RemovePeerData represents the data needed to remove a peer +type RemovePeerData struct { + SiteId int `json:"siteId"` +} + +type RelayPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + +func fixKey(key string) string { + // Remove any whitespace + key = strings.TrimSpace(key) + + // Decode from base64 + decoded, err := base64.StdEncoding.DecodeString(key) + if err != nil { + logger.Fatal("Error decoding base64") + } + + // Convert to hex + return hex.EncodeToString(decoded) +} + +func parseLogLevel(level string) logger.LogLevel { + switch strings.ToUpper(level) { + case "DEBUG": + return logger.DEBUG + case "INFO": + return logger.INFO + case "WARN": + return logger.WARN + case "ERROR": + return logger.ERROR + case "FATAL": + return logger.FATAL + default: + return logger.INFO // default to INFO if invalid level provided + } +} + +func mapToWireGuardLogLevel(level logger.LogLevel) int { + switch level { + case logger.DEBUG: + return device.LogLevelVerbose + // case logger.INFO: + // return device.LogLevel + case logger.WARN: + return device.LogLevelError + case logger.ERROR, logger.FATAL: + return device.LogLevelSilent + default: + return device.LogLevelSilent + } +} + +func resolveDomain(domain string) (string, error) { + // First handle any protocol prefix + domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://") + + // if there are any trailing slashes, remove them + domain = strings.TrimSuffix(domain, "/") + + // Now split host and port + host, port, err := net.SplitHostPort(domain) + if err != nil { + // No port found, use the domain as is + host = domain + port = "" + } + + // Lookup IP addresses + ips, err := net.LookupIP(host) + if err != nil { + return "", fmt.Errorf("DNS lookup failed: %v", err) + } + + if len(ips) == 0 { + return "", fmt.Errorf("no IP addresses found for domain %s", host) + } + + // Get the first IPv4 address if available + var ipAddr string + for _, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ipAddr = ipv4.String() + break + } + } + + // If no IPv4 found, use the first IP (might be IPv6) + if ipAddr == "" { + ipAddr = ips[0].String() + } + + // Add port back if it existed + if port != "" { + ipAddr = net.JoinHostPort(ipAddr, port) + } + + return ipAddr, nil +} + +func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string) error { + if gerbilServerPubKey == "" || olmToken == "" { + return nil + } + + payload := struct { + OlmID string `json:"olmId"` + Token string `json:"token"` + }{ + OlmID: olmID, + Token: olmToken, + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := encryptPayload(payloadBytes, gerbilServerPubKey) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %v", err) + } + + jsonData, err := json.Marshal(encryptedPayload) + if err != nil { + return fmt.Errorf("failed to marshal encrypted payload: %v", err) + } + + _, err = conn.WriteToUDP(jsonData, remoteAddr) + if err != nil { + return fmt.Errorf("failed to send UDP packet: %v", err) + } + + logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) + + return nil +} + +func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(serverPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange (replacing deprecated ScalarMult) + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} + +func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { + // Check if hole punching is already running + if holePunchRunning { + logger.Debug("UDP hole punch already running, skipping new request") + return + } + + // Set the flag to indicate hole punching is running + holePunchRunning = true + defer func() { + holePunchRunning = false + logger.Info("UDP hole punch goroutine ended") + }() + + host, err := resolveDomain(endpoint) + if err != nil { + logger.Error("Failed to resolve endpoint: %v", err) + return + } + + serverAddr := host + ":21820" + + // Create the UDP connection once and reuse it + localAddr := &net.UDPAddr{ + Port: int(sourcePort), + IP: net.IPv4zero, + } + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address: %v", err) + return + } + + conn, err := net.ListenUDP("udp", localAddr) + if err != nil { + logger.Error("Failed to bind UDP socket: %v", err) + return + } + defer conn.Close() + + // Execute once immediately before starting the loop + if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-ticker.C: + if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +} + +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + // Create a slice of all ports in the range + portRange := make([]uint16, maxPort-minPort+1) + for i := range portRange { + portRange[i] = minPort + uint16(i) + } + + // Fisher-Yates shuffle to randomize the port order + rand.Seed(uint64(time.Now().UnixNano())) + for i := len(portRange) - 1; i > 0; i-- { + j := rand.Intn(i + 1) + portRange[i], portRange[j] = portRange[j], portRange[i] + } + + // Try each port in the randomized order + for _, port := range portRange { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + continue // Port is in use or there was an error, try next port + } + _ = conn.SetDeadline(time.Now()) + conn.Close() + return port, nil + } + + return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) +} + +func sendPing(olm *websocket.Client) error { + err := olm.SendMessage("olm/ping", map[string]interface{}{ + "timestamp": time.Now().Unix(), + }) + if err != nil { + logger.Error("Failed to send ping message: %v", err) + return err + } + logger.Debug("Sent ping message") + return nil +} + +func keepSendingPing(olm *websocket.Client) { + // Send ping immediately on startup + if err := sendPing(olm); err != nil { + logger.Error("Failed to send initial ping: %v", err) + } else { + logger.Info("Sent initial ping message") + } + + // Set up ticker for one minute intervals + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-stopPing: + logger.Info("Stopping ping messages") + return + case <-ticker.C: + if err := sendPing(olm); err != nil { + logger.Error("Failed to send periodic ping: %v", err) + } + } + } +} + +// ConfigurePeer sets up or updates a peer within the WireGuard device +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { + siteHost, err := resolveDomain(siteConfig.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) + } + + // Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP + allowedIp := strings.Split(siteConfig.ServerIP, "/") + if len(allowedIp) > 1 { + allowedIp[1] = "32" + } else { + allowedIp = append(allowedIp, "32") + } + allowedIpStr := strings.Join(allowedIp, "/") + + // Collect all allowed IPs in a slice + var allowedIPs []string + allowedIPs = append(allowedIPs, allowedIpStr) + + // If we have anything in remoteSubnets, add those as well + if siteConfig.RemoteSubnets != "" { + // Split remote subnets by comma and add each one + remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",") + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet != "" { + allowedIPs = append(allowedIPs, subnet) + } + } + } + + // Construct WireGuard config for this peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String()))) + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey))) + + // Add each allowed IP separately + for _, allowedIP := range allowedIPs { + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) + } + + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) + configBuilder.WriteString("persistent_keepalive_interval=1\n") + + config := configBuilder.String() + logger.Debug("Configuring peer with config: %s", config) + + err = dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard peer: %v", err) + } + + // Set up peer monitoring + if peerMonitor != nil { + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] + monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, siteConfig.ServerPort+1) // +1 for the monitor port + logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) + + primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + wgConfig := &peermonitor.WireGuardConfig{ + SiteID: siteConfig.SiteId, + PublicKey: fixKey(siteConfig.PublicKey), + ServerIP: strings.Split(siteConfig.ServerIP, "/")[0], + Endpoint: siteConfig.Endpoint, + PrimaryRelay: primaryRelay, + } + + err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig) + if err != nil { + logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) + } else { + logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) + } + } + + return nil +} + +// RemovePeer removes a peer from the WireGuard device +func RemovePeer(dev *device.Device, siteId int, publicKey string) error { + // Construct WireGuard config to remove the peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey))) + configBuilder.WriteString("remove=true\n") + + config := configBuilder.String() + logger.Debug("Removing peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to remove WireGuard peer: %v", err) + } + + // Stop monitoring this peer + if peerMonitor != nil { + peerMonitor.RemovePeer(siteId) + logger.Info("Stopped monitoring for site %d", siteId) + } + + return nil +} + +// ConfigureInterface configures a network interface with an IP address and brings it up +func ConfigureInterface(interfaceName string, wgData WgData) error { + var ipAddr string = wgData.TunnelIP + + // Parse the IP address and network + ip, ipNet, err := net.ParseCIDR(ipAddr) + if err != nil { + return fmt.Errorf("invalid IP address: %v", err) + } + + switch runtime.GOOS { + case "linux": + return configureLinux(interfaceName, ip, ipNet) + case "darwin": + return configureDarwin(interfaceName, ip, ipNet) + case "windows": + return configureWindows(interfaceName, ip, ipNet) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring Windows interface: %s", interfaceName) + + // Calculate mask string (e.g., 255.255.255.0) + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + // Set the IP address using netsh + cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", + fmt.Sprintf("name=%s", interfaceName), + "source=static", + fmt.Sprintf("addr=%s", ip.String()), + fmt.Sprintf("mask=%s", maskIP.String())) + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("netsh command failed: %v, output: %s", err, out) + } + + // Bring up the interface if needed (in Windows, setting the IP usually brings it up) + // But we'll explicitly enable it to be sure + cmd = exec.Command("netsh", "interface", "set", "interface", + interfaceName, + "admin=enable") + + logger.Info("Running command: %v", cmd) + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out) + } + + // delay 2 seconds + time.Sleep(8 * time.Second) + + // Wait for the interface to be up and have the correct IP + err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) + if err != nil { + return fmt.Errorf("interface did not come up within timeout: %v", err) + } + + return nil +} + +// waitForInterfaceUp polls the network interface until it's up or times out +func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { + logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) + deadline := time.Now().Add(timeout) + pollInterval := 500 * time.Millisecond + + for time.Now().Before(deadline) { + // Check if interface exists and is up + iface, err := net.InterfaceByName(interfaceName) + if err == nil { + // Check if interface is up + if iface.Flags&net.FlagUp != 0 { + // Check if it has the expected IP + addrs, err := iface.Addrs() + if err == nil { + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if ok && ipNet.IP.Equal(expectedIP) { + logger.Info("Interface %s is up with correct IP", interfaceName) + return nil // Interface is up with correct IP + } + } + logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) + } + } else { + logger.Info("Interface %s exists but is not up yet", interfaceName) + } + } else { + logger.Info("Interface %s not found yet: %v", interfaceName, err) + } + + // Wait before next check + time.Sleep(pollInterval) + } + + return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) +} + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "windows" { + return nil + } + + var cmd *exec.Cmd + + // Parse destination to get the IP and subnet + ip, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Calculate the subnet mask + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "add", + ip.String(), + "mask", maskIP.String(), + gateway, + "metric", "1") + } else if interfaceName != "" { + // First, get the interface index + indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces") + output, err := indexCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to get interface index: %v, output: %s", err, output) + } + + // Parse the output to find the interface index + lines := strings.Split(string(output), "\n") + var ifIndex string + for _, line := range lines { + if strings.Contains(line, interfaceName) { + fields := strings.Fields(line) + if len(fields) > 0 { + ifIndex = fields[0] + break + } + } + } + + if ifIndex == "" { + return fmt.Errorf("could not find index for interface %s", interfaceName) + } + + // Convert to integer to validate + idx, err := strconv.Atoi(ifIndex) + if err != nil { + return fmt.Errorf("invalid interface index: %v", err) + } + + // Route via interface using the index + cmd = exec.Command("route", "add", + ip.String(), + "mask", maskIP.String(), + "0.0.0.0", + "if", strconv.Itoa(idx)) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func WindowsRemoveRoute(destination string) error { + // Parse destination to get the IP + ip, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Calculate the subnet mask + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + cmd := exec.Command("route", "delete", + ip.String(), + "mask", maskIP.String()) + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func findUnusedUTUN() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("failed to list interfaces: %v", err) + } + used := make(map[int]bool) + re := regexp.MustCompile(`^utun(\d+)$`) + for _, iface := range ifaces { + if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { + if num, err := strconv.Atoi(matches[1]); err == nil { + used[num] = true + } + } + } + // Try utun0 up to utun255. + for i := 0; i < 256; i++ { + if !used[i] { + return fmt.Sprintf("utun%d", i), nil + } + } + return "", fmt.Errorf("no unused utun interface found") +} + +func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring darwin interface: %s", interfaceName) + + prefix, _ := ipNet.Mask.Size() + ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) + + cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) + } + + // Bring up the interface + cmd = exec.Command("ifconfig", interfaceName, "up") + logger.Info("Running command: %v", cmd) + + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) + } + + return nil +} + +func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // Create the IP address attributes + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + } + + // Add the IP address to the interface + if err := netlink.AddrAdd(link, addr); err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Bring up the interface + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + return nil +} + +func DarwinAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "darwin" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func DarwinRemoveRoute(destination string) error { + if runtime.GOOS != "darwin" { + return nil + } + + cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "linux" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("ip", "route", "add", destination, "via", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ip route command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxRemoveRoute(destination string) error { + if runtime.GOOS != "linux" { + return nil + } + + cmd := exec.Command("ip", "route", "del", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +// addRouteForServerIP adds an OS-specific route for the server IP +func addRouteForServerIP(serverIP, interfaceName string) error { + if runtime.GOOS == "darwin" { + return DarwinAddRoute(serverIP, "", interfaceName) + } + // else if runtime.GOOS == "windows" { + // return WindowsAddRoute(serverIP, "", interfaceName) + // } else if runtime.GOOS == "linux" { + // return LinuxAddRoute(serverIP, "", interfaceName) + // } + return nil +} + +// removeRouteForServerIP removes an OS-specific route for the server IP +func removeRouteForServerIP(serverIP string) error { + if runtime.GOOS == "darwin" { + return DarwinRemoveRoute(serverIP) + } + // else if runtime.GOOS == "windows" { + // return WindowsRemoveRoute(serverIP) + // } else if runtime.GOOS == "linux" { + // return LinuxRemoveRoute(serverIP) + // } + return nil +} + +// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets +func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { + if remoteSubnets == "" { + return nil + } + + // Split remote subnets by comma and add routes for each one + subnets := strings.Split(remoteSubnets, ",") + for _, subnet := range subnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + // Add route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Added route for remote subnet: %s", subnet) + } + return nil +} + +// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets +func removeRoutesForRemoteSubnets(remoteSubnets string) error { + if remoteSubnets == "" { + return nil + } + + // Split remote subnets by comma and remove routes for each one + subnets := strings.Split(remoteSubnets, ",") + for _, subnet := range subnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + // Remove route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Removed route for remote subnet: %s", subnet) + } + return nil +} diff --git a/docker-compose.yml b/docker-compose.yml index 86f4ca1..b63cf27 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,10 +1,10 @@ services: - newt: - image: fosrl/newt:latest - container_name: newt + olm: + image: fosrl/olm:latest + container_name: olm restart: unless-stopped environment: - PANGOLIN_ENDPOINT=https://example.com - - NEWT_ID=2ix2t8xk22ubpfy - - NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 + - OLM_ID=2ix2t8xk22ubpfy + - OLM_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 - LOG_LEVEL=DEBUG \ No newline at end of file diff --git a/entrypoint.sh b/entrypoint.sh index 79ae7a0..5ca3dda 100644 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -4,7 +4,7 @@ set -e # first arg is `-f` or `--some-option` if [ "${1#-}" != "$1" ]; then - set -- newt "$@" + set -- olm "$@" fi exec "$@" \ No newline at end of file diff --git a/go.mod b/go.mod index 2cc0c19..8827763 100644 --- a/go.mod +++ b/go.mod @@ -1,20 +1,55 @@ -module github.com/fosrl/newt +module github.com/fosrl/olm go 1.23.1 toolchain go1.23.2 -require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 +require ( + github.com/fosrl/newt v0.0.0-20250724194014-008be54c554a + github.com/vishvananda/netlink v1.3.1 + golang.org/x/crypto v0.40.0 + golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 + golang.org/x/net v0.42.0 + golang.org/x/sys v0.34.0 + golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 +) require ( - github.com/google/btree v1.1.2 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/containerd/errdefs v1.0.0 // indirect + github.com/containerd/errdefs/pkg v0.3.0 // indirect + github.com/distribution/reference v0.6.0 // indirect + github.com/docker/docker v28.3.2+incompatible // indirect + github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/btree v1.1.3 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/gopacket v1.1.19 // indirect github.com/gorilla/websocket v1.5.3 // indirect - golang.org/x/crypto v0.28.0 // indirect - golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/time v0.7.0 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/vishvananda/netns v0.0.5 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + golang.org/x/mod v0.26.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.35.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect - gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect + gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect + software.sslmate.com/src/go-pkcs12 v0.5.0 // indirect ) diff --git a/go.sum b/go.sum index d95ab3a..18c5cff 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,127 @@ -github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= -github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v28.3.2+incompatible h1:wn66NJ6pWB1vBZIilP8G3qQPqHy5XymfYn5vsqeA5oA= +github.com/docker/docker v28.3.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fosrl/newt v0.0.0-20250717220102-cd86e6b6de83 h1:jI6tP2sJNNb70Y+Ixq+oI06fDPnGUbarz/r67g7KvB8= +github.com/fosrl/newt v0.0.0-20250717220102-cd86e6b6de83/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo= +github.com/fosrl/newt v0.0.0-20250718235538-510e78437ca4 h1:bK/MQyTOLGthrXZ7ExvOCdW0EH0o9b5vwk/+UKnNdg0= +github.com/fosrl/newt v0.0.0-20250718235538-510e78437ca4/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo= +github.com/fosrl/newt v0.0.0-20250724190153-64c22a94a47a h1:Jgd60yfFJxb5z6L3LcoraaosHjiRgKLnMz6T3mv3D4Q= +github.com/fosrl/newt v0.0.0-20250724190153-64c22a94a47a/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo= +github.com/fosrl/newt v0.0.0-20250724194014-008be54c554a h1:17r/Uhef6aIxpO0xYGI3771LJx7cTyc1WziDOgghc54= +github.com/fosrl/newt v0.0.0-20250724194014-008be54c554a/go.mod h1:oqHsCx1rsEc8hAVGVXemfeolIwlr19biJSQiLYi7Mvo= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= +github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= -golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= +github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= -golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= +golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= +golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= +gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c= +gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs= +software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M= +software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go new file mode 100644 index 0000000..a3c3d3b --- /dev/null +++ b/httpserver/httpserver.go @@ -0,0 +1,177 @@ +package httpserver + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +// ConnectionRequest defines the structure for an incoming connection request +type ConnectionRequest struct { + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` +} + +// PeerStatus represents the status of a peer connection +type PeerStatus struct { + SiteID int `json:"siteId"` + Connected bool `json:"connected"` + RTT time.Duration `json:"rtt"` + LastSeen time.Time `json:"lastSeen"` +} + +// StatusResponse is returned by the status endpoint +type StatusResponse struct { + Status string `json:"status"` + Connected bool `json:"connected"` + TunnelIP string `json:"tunnelIP,omitempty"` + PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` +} + +// HTTPServer represents the HTTP server and its state +type HTTPServer struct { + addr string + server *http.Server + connectionChan chan ConnectionRequest + statusMu sync.RWMutex + peerStatuses map[int]*PeerStatus + connectedAt time.Time + isConnected bool +} + +// NewHTTPServer creates a new HTTP server +func NewHTTPServer(addr string) *HTTPServer { + s := &HTTPServer{ + addr: addr, + connectionChan: make(chan ConnectionRequest, 1), + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + +// Start starts the HTTP server +func (s *HTTPServer) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/connect", s.handleConnect) + mux.HandleFunc("/status", s.handleStatus) + + s.server = &http.Server{ + Addr: s.addr, + Handler: mux, + } + + logger.Info("Starting HTTP server on %s", s.addr) + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Error("HTTP server error: %v", err) + } + }() + + return nil +} + +// Stop stops the HTTP server +func (s *HTTPServer) Stop() error { + logger.Info("Stopping HTTP server") + return s.server.Close() +} + +// GetConnectionChannel returns the channel for receiving connection requests +func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest { + return s.connectionChan +} + +// UpdatePeerStatus updates the status of a peer +func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if !exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.Connected = connected + status.RTT = rtt + status.LastSeen = time.Now() +} + +// SetConnectionStatus sets the overall connection status +func (s *HTTPServer) SetConnectionStatus(isConnected bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + s.isConnected = isConnected + + if isConnected { + s.connectedAt = time.Now() + } else { + // Clear peer statuses when disconnected + s.peerStatuses = make(map[int]*PeerStatus) + } +} + +// handleConnect handles the /connect endpoint +func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req ConnectionRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if req.ID == "" || req.Secret == "" || req.Endpoint == "" { + http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest) + return + } + + // Send the request to the main goroutine + s.connectionChan <- req + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]string{ + "status": "connection request accepted", + }) +} + +// handleStatus handles the /status endpoint +func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + s.statusMu.RLock() + defer s.statusMu.RUnlock() + + resp := StatusResponse{ + Connected: s.isConnected, + PeerStatuses: s.peerStatuses, + } + + if s.isConnected { + resp.Status = "connected" + } else { + resp.Status = "disconnected" + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} diff --git a/logger/level.go b/logger/level.go deleted file mode 100644 index 175995f..0000000 --- a/logger/level.go +++ /dev/null @@ -1,27 +0,0 @@ -package logger - -type LogLevel int - -const ( - DEBUG LogLevel = iota - INFO - WARN - ERROR - FATAL -) - -var levelStrings = map[LogLevel]string{ - DEBUG: "DEBUG", - INFO: "INFO", - WARN: "WARN", - ERROR: "ERROR", - FATAL: "FATAL", -} - -// String returns the string representation of the log level -func (l LogLevel) String() string { - if s, ok := levelStrings[l]; ok { - return s - } - return "UNKNOWN" -} diff --git a/logger/logger.go b/logger/logger.go deleted file mode 100644 index 9ef486d..0000000 --- a/logger/logger.go +++ /dev/null @@ -1,106 +0,0 @@ -package logger - -import ( - "fmt" - "log" - "os" - "sync" - "time" -) - -// Logger struct holds the logger instance -type Logger struct { - logger *log.Logger - level LogLevel -} - -var ( - defaultLogger *Logger - once sync.Once -) - -// NewLogger creates a new logger instance -func NewLogger() *Logger { - return &Logger{ - logger: log.New(os.Stdout, "", 0), - level: DEBUG, - } -} - -// Init initializes the default logger -func Init() *Logger { - once.Do(func() { - defaultLogger = NewLogger() - }) - return defaultLogger -} - -// GetLogger returns the default logger instance -func GetLogger() *Logger { - if defaultLogger == nil { - Init() - } - return defaultLogger -} - -// SetLevel sets the minimum logging level -func (l *Logger) SetLevel(level LogLevel) { - l.level = level -} - -// log handles the actual logging -func (l *Logger) log(level LogLevel, format string, args ...interface{}) { - if level < l.level { - return - } - timestamp := time.Now().Format("2006/01/02 15:04:05") - message := fmt.Sprintf(format, args...) - l.logger.Printf("%s: %s %s", level.String(), timestamp, message) -} - -// Debug logs debug level messages -func (l *Logger) Debug(format string, args ...interface{}) { - l.log(DEBUG, format, args...) -} - -// Info logs info level messages -func (l *Logger) Info(format string, args ...interface{}) { - l.log(INFO, format, args...) -} - -// Warn logs warning level messages -func (l *Logger) Warn(format string, args ...interface{}) { - l.log(WARN, format, args...) -} - -// Error logs error level messages -func (l *Logger) Error(format string, args ...interface{}) { - l.log(ERROR, format, args...) -} - -// Fatal logs fatal level messages and exits -func (l *Logger) Fatal(format string, args ...interface{}) { - l.log(FATAL, format, args...) - os.Exit(1) -} - -// Global helper functions -func Debug(format string, args ...interface{}) { - GetLogger().Debug(format, args...) -} - -func Info(format string, args ...interface{}) { - GetLogger().Info(format, args...) -} - -func Warn(format string, args ...interface{}) { - GetLogger().Warn(format, args...) -} - -func Error(format string, args ...interface{}) { - GetLogger().Error(format, args...) -} - -func Fatal(format string, args ...interface{}) { - GetLogger().Fatal(format, args...) -} diff --git a/main.go b/main.go index 786ecbd..6c17388 100644 --- a/main.go +++ b/main.go @@ -1,303 +1,381 @@ package main import ( - "bytes" - "encoding/base64" - "encoding/hex" + "context" "encoding/json" "flag" "fmt" - "math/rand" "net" - "net/netip" "os" "os/signal" + "runtime" "strconv" - "strings" "syscall" "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" + "github.com/fosrl/olm/httpserver" + "github.com/fosrl/olm/peermonitor" + "github.com/fosrl/olm/wgtester" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/netstack" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type WgData struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - TunnelIP string `json:"tunnelIP"` - Targets TargetsByType `json:"targets"` -} - -type TargetsByType struct { - UDP []string `json:"udp"` - TCP []string `json:"tcp"` -} - -type TargetData struct { - Targets []string `json:"targets"` -} - -func fixKey(key string) string { - // Remove any whitespace - key = strings.TrimSpace(key) - - // Decode from base64 - decoded, err := base64.StdEncoding.DecodeString(key) - if err != nil { - logger.Fatal("Error decoding base64:", err) - } - - // Convert to hex - return hex.EncodeToString(decoded) -} - -func ping(tnet *netstack.Net, dst string) error { - logger.Info("Pinging %s", dst) - socket, err := tnet.Dial("ping4", dst) - if err != nil { - return fmt.Errorf("failed to create ICMP socket: %w", err) - } - defer socket.Close() - - requestPing := icmp.Echo{ - Seq: rand.Intn(1 << 16), - Data: []byte("gopher burrow"), - } - - icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) - if err != nil { - return fmt.Errorf("failed to marshal ICMP message: %w", err) - } - - if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { - return fmt.Errorf("failed to set read deadline: %w", err) - } - - start := time.Now() - _, err = socket.Write(icmpBytes) - if err != nil { - return fmt.Errorf("failed to write ICMP packet: %w", err) - } - - n, err := socket.Read(icmpBytes[:]) - if err != nil { - return fmt.Errorf("failed to read ICMP packet: %w", err) - } - - replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) - if err != nil { - return fmt.Errorf("failed to parse ICMP packet: %w", err) - } - - replyPing, ok := replyPacket.Body.(*icmp.Echo) - if !ok { - return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) - } - - if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", - replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) - } - - logger.Info("Ping latency: %v", time.Since(start)) - return nil -} - -func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - go func() { - for { - select { - case <-ticker.C: - err := ping(tnet, serverIP) - if err != nil { - logger.Warn("Periodic ping failed: %v", err) - logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?") - } - case <-stopChan: - logger.Info("Stopping ping check") - return - } - } - }() -} - -func pingWithRetry(tnet *netstack.Net, dst string) error { - const ( - maxAttempts = 5 - retryDelay = 2 * time.Second - ) - - var lastErr error - for attempt := 1; attempt <= maxAttempts; attempt++ { - logger.Info("Ping attempt %d of %d", attempt, maxAttempts) - - if err := ping(tnet, dst); err != nil { - lastErr = err - logger.Warn("Ping attempt %d failed: %v", attempt, err) - - if attempt < maxAttempts { - time.Sleep(retryDelay) - continue - } - return fmt.Errorf("all ping attempts failed after %d tries, last error: %w", - maxAttempts, lastErr) - } - - // Successful ping - return nil - } - - // This shouldn't be reached due to the return in the loop, but added for completeness - return fmt.Errorf("unexpected error: all ping attempts failed") -} - -func parseLogLevel(level string) logger.LogLevel { - switch strings.ToUpper(level) { - case "DEBUG": - return logger.DEBUG - case "INFO": - return logger.INFO - case "WARN": - return logger.WARN - case "ERROR": - return logger.ERROR - case "FATAL": - return logger.FATAL - default: - return logger.INFO // default to INFO if invalid level provided - } -} - -func mapToWireGuardLogLevel(level logger.LogLevel) int { - switch level { - case logger.DEBUG: - return device.LogLevelVerbose - // case logger.INFO: - // return device.LogLevel - case logger.WARN: - return device.LogLevelError - case logger.ERROR, logger.FATAL: - return device.LogLevelSilent - default: - return device.LogLevelSilent - } -} - -func resolveDomain(domain string) (string, error) { - // Check if there's a port in the domain - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Remove any protocol prefix if present - if strings.HasPrefix(host, "http://") { - host = strings.TrimPrefix(host, "http://") - } else if strings.HasPrefix(host, "https://") { - host = strings.TrimPrefix(host, "https://") - } - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil -} - func main() { + // Check if we're running as a Windows service + if isWindowsService() { + runService("OlmWireguardService", false, os.Args[1:]) + fmt.Println("Running as Windows service") + return + } + + // Handle service management commands on Windows + if runtime.GOOS == "windows" && len(os.Args) > 1 { + switch os.Args[1] { + case "install": + err := installService() + if err != nil { + fmt.Printf("Failed to install service: %v\n", err) + os.Exit(1) + } + fmt.Println("Service installed successfully") + return + case "remove", "uninstall": + err := removeService() + if err != nil { + fmt.Printf("Failed to remove service: %v\n", err) + os.Exit(1) + } + fmt.Println("Service removed successfully") + return + case "start": + // Pass the remaining arguments (after "start") to the service + serviceArgs := os.Args[2:] + err := startService(serviceArgs) + if err != nil { + fmt.Printf("Failed to start service: %v\n", err) + os.Exit(1) + } + fmt.Println("Service started successfully") + return + case "stop": + err := stopService() + if err != nil { + fmt.Printf("Failed to stop service: %v\n", err) + os.Exit(1) + } + fmt.Println("Service stopped successfully") + return + case "status": + status, err := getServiceStatus() + if err != nil { + fmt.Printf("Failed to get service status: %v\n", err) + os.Exit(1) + } + fmt.Printf("Service status: %s\n", status) + return + case "debug": + // get the status and if it is Not Installed then install it first + status, err := getServiceStatus() + if err != nil { + fmt.Printf("Failed to get service status: %v\n", err) + os.Exit(1) + } + if status == "Not Installed" { + err := installService() + if err != nil { + fmt.Printf("Failed to install service: %v\n", err) + os.Exit(1) + } + fmt.Println("Service installed successfully, now running in debug mode") + } + + // Pass the remaining arguments (after "debug") to the service + serviceArgs := os.Args[2:] + err = debugService(serviceArgs) + if err != nil { + fmt.Printf("Failed to debug service: %v\n", err) + os.Exit(1) + } + return + case "logs": + err := watchLogFile(false) + if err != nil { + fmt.Printf("Failed to watch log file: %v\n", err) + os.Exit(1) + } + return + case "help", "--help", "-h": + fmt.Println("Olm WireGuard VPN Client") + fmt.Println("\nWindows Service Management:") + fmt.Println(" install Install the service") + fmt.Println(" remove Remove the service") + fmt.Println(" start Start the service") + fmt.Println(" stop Stop the service") + fmt.Println(" status Show service status") + fmt.Println(" debug Run service in debug mode") + fmt.Println("\nFor console mode, run without arguments or with standard flags.") + return + default: + // get the status and if it is Not Installed then install it first + status, err := getServiceStatus() + if err != nil { + fmt.Printf("Failed to get service status: %v\n", err) + os.Exit(1) + } + if status == "Not Installed" { + err := installService() + if err != nil { + fmt.Printf("Failed to install service: %v\n", err) + os.Exit(1) + } + fmt.Println("Service installed successfully, now running") + } + + // Pass the remaining arguments (after "debug") to the service + serviceArgs := os.Args[1:] + err = debugService(serviceArgs) + if err != nil { + fmt.Printf("Failed to debug service: %v\n", err) + os.Exit(1) + } + return + } + } + + // Run in console mode + runOlmMain(context.Background()) +} + +func runOlmMain(ctx context.Context) { + runOlmMainWithArgs(ctx, os.Args[1:]) +} + +func runOlmMainWithArgs(ctx context.Context, args []string) { + // Log that we've entered the main function + // fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args) + + // Create a new FlagSet for parsing service arguments + serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError) + var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + interfaceName string + enableHTTP bool + httpAddr string + testMode bool // Add this var for the test flag + testTarget string // Add this var for test target + pingInterval time.Duration + pingTimeout time.Duration + doHolepunch bool + connected bool ) - // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values + stopHolepunch = make(chan struct{}) + stopPing = make(chan struct{}) + + // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values endpoint = os.Getenv("PANGOLIN_ENDPOINT") - id = os.Getenv("NEWT_ID") - secret = os.Getenv("NEWT_SECRET") + id = os.Getenv("OLM_ID") + secret = os.Getenv("OLM_SECRET") mtu = os.Getenv("MTU") dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") + interfaceName = os.Getenv("INTERFACE") + httpAddr = os.Getenv("HTTP_ADDR") + pingIntervalStr := os.Getenv("PING_INTERVAL") + pingTimeoutStr := os.Getenv("PING_TIMEOUT") + doHolepunch = os.Getenv("HOLEPUNCH") == "true" // Default to true, can be overridden by flag if endpoint == "" { - flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") + serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") } if id == "" { - flag.StringVar(&id, "id", "", "Newt ID") + serviceFlags.StringVar(&id, "id", "", "Olm ID") } if secret == "" { - flag.StringVar(&secret, "secret", "", "Newt secret") + serviceFlags.StringVar(&secret, "secret", "", "Olm secret") } if mtu == "" { - flag.StringVar(&mtu, "mtu", "1280", "MTU to use") + serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use") } if dns == "" { - flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") + serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") } if logLevel == "" { - flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + } + if interfaceName == "" { + serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") + } + if httpAddr == "" { + serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')") + } + if pingIntervalStr == "" { + serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") + } + if pingTimeoutStr == "" { + serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") + } + serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests") + serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)") + + // Parse the service arguments + if err := serviceFlags.Parse(args); err != nil { + fmt.Printf("Error parsing service arguments: %v\n", err) + return } - // do a --version check - version := flag.Bool("version", false, "Print the version") + // Debug: Print final values after flag parsing + // fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret) - flag.Parse() - - if *version { - fmt.Println("Newt version replaceme") - os.Exit(0) + // Parse ping intervals + if pingIntervalStr != "" { + pingInterval, err = time.ParseDuration(pingIntervalStr) + if err != nil { + fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr) + pingInterval = 3 * time.Second + } + } else { + pingInterval = 3 * time.Second } - logger.Init() + if pingTimeoutStr != "" { + pingTimeout, err = time.ParseDuration(pingTimeoutStr) + if err != nil { + fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr) + pingTimeout = 5 * time.Second + } + } else { + pingTimeout = 5 * time.Second + } + + // Setup Windows event logging if on Windows + if runtime.GOOS == "windows" { + setupWindowsEventLog() + } else { + // Initialize logger for non-Windows platforms + logger.Init() + } loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + // Log startup information + logger.Debug("Olm service starting...") + logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) + logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr) + + // Handle test mode + if testMode { + if testTarget == "" { + logger.Fatal("Test mode requires -test-target to be set to a server:port") + } + + logger.Info("Running in test mode, connecting to %s", testTarget) + + // Create a new tester client + tester, err := wgtester.NewClient(testTarget) + if err != nil { + logger.Fatal("Failed to create tester client: %v", err) + } + defer tester.Close() + + // Test connection with a 2-second timeout + connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second) + + if connected { + logger.Info("Connection test successful! RTT: %v", rtt) + fmt.Printf("Connection test successful! RTT: %v\n", rtt) + os.Exit(0) + } else { + logger.Error("Connection test failed - no response received") + fmt.Println("Connection test failed - no response received") + os.Exit(1) + } + } + + var httpServer *httpserver.HTTPServer + if enableHTTP { + httpServer = httpserver.NewHTTPServer(httpAddr) + if err := httpServer.Start(); err != nil { + logger.Fatal("Failed to start HTTP server: %v", err) + } + + // Use a goroutine to handle connection requests + go func() { + for req := range httpServer.GetConnectionChannel() { + logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) + + // Set the connection parameters + id = req.ID + secret = req.Secret + endpoint = req.Endpoint + } + }() + } + + // // Check if required parameters are missing and provide helpful guidance + // missingParams := []string{} + // if id == "" { + // missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)") + // } + // if secret == "" { + // missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)") + // } + // if endpoint == "" { + // missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)") + // } + + // if len(missingParams) > 0 { + // logger.Error("Missing required parameters: %v", missingParams) + // logger.Error("Either provide them as command line flags or set as environment variables") + // fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams) + // fmt.Printf("Please provide them as command line flags or set as environment variables\n") + // if !enableHTTP { + // logger.Error("HTTP server is disabled, cannot receive parameters via API") + // fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n") + // return + // } + // } + + // // wait until we have a client id and secret and endpoint + // waitCount := 0 + // for id == "" || secret == "" || endpoint == "" { + // select { + // case <-ctx.Done(): + // logger.Info("Context cancelled while waiting for credentials") + // return + // default: + // missing := []string{} + // if id == "" { + // missing = append(missing, "id") + // } + // if secret == "" { + // missing = append(missing, "secret") + // } + // if endpoint == "" { + // missing = append(missing, "endpoint") + // } + // waitCount++ + // if waitCount%10 == 1 { // Log every 10 seconds instead of every second + // logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount) + // } + // time.Sleep(1 * time.Second) + // } + // } + // parse the mtu string into an int mtuInt, err = strconv.Atoi(mtu) if err != nil { @@ -309,54 +387,78 @@ func main() { logger.Fatal("Failed to generate private key: %v", err) } - // Create a new client - client, err := websocket.NewClient( + // Create a new olm + olm, err := websocket.NewClient( + "olm", id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, + pingInterval, + pingTimeout, ) if err != nil { - logger.Fatal("Failed to create client: %v", err) + logger.Fatal("Failed to create olm: %v", err) } + endpoint = olm.GetConfig().Endpoint // Update endpoint from config + id = olm.GetConfig().ID // Update ID from config // Create TUN device and network stack - var tun tun.Device - var tnet *netstack.Net var dev *device.Device - var pm *proxy.ProxyManager - var connected bool var wgData WgData + var holePunchData HolePunchData + var uapiListener net.Listener + var tdev tun.Device - client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - if pm != nil { - pm.Stop() + sourcePort, err := FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + os.Exit(1) + } + + olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return } - if dev != nil { - dev.Close() + + if err := json.Unmarshal(jsonData, &holePunchData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return } - client.Close() + + gerbilServerPubKey = holePunchData.ServerPubKey + + go keepSendingUDPHolePunch(holePunchData.Endpoint, id, sourcePort) }) - pingStopChan := make(chan struct{}) - defer close(pingStopChan) - // Register handlers for different message types - client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { - logger.Info("Received registration message") + olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) if connected { - logger.Info("Already connected! But I will send a ping anyway...") - // ping(tnet, wgData.ServerIP) - err = pingWithRetry(tnet, wgData.ServerIP) - if err != nil { - // Handle complete failure after all retries - logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err) - logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?") - } + logger.Info("Already connected. Ignoring new connection request.") return } + if stopRegister != nil { + stopRegister() + stopRegister = nil + } + + close(stopHolepunch) + + // wait 10 milliseconds to ensure the previous connection is closed + time.Sleep(10 * time.Millisecond) + + // if there is an existing tunnel then close it + if dev != nil { + logger.Info("Got new message. Closing existing tunnel!") + dev.Close() + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) @@ -368,38 +470,82 @@ func main() { return } - logger.Info("Received: %+v", msg) - tun, tnet, err = netstack.CreateNetTUN( - []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, - []netip.Addr{netip.MustParseAddr(dns)}, - mtuInt) + tdev, err = func() (tun.Device, error) { + tunFdStr := os.Getenv(ENV_WG_TUN_FD) + + // if on macOS, call findUnusedUTUN to get a new utun device + if runtime.GOOS == "darwin" { + interfaceName, err := findUnusedUTUN() + if err != nil { + return nil, err + } + return tun.CreateTUN(interfaceName, mtuInt) + } + + if tunFdStr == "" { + return tun.CreateTUN(interfaceName, mtuInt) + } + + return createTUNFromFD(tunFdStr, mtuInt) + }() + if err != nil { logger.Error("Failed to create TUN device: %v", err) + return } - // Create WireGuard device - dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger( + realInterfaceName, err2 := tdev.Name() + if err2 == nil { + interfaceName = realInterfaceName + } + + // open UAPI file (or use supplied fd) + fileUAPI, err := func() (*os.File, error) { + uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) + if uapiFdStr == "" { + return uapiOpen(interfaceName) + } + + // use supplied fd + + fd, err := strconv.ParseUint(uapiFdStr, 10, 32) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(fd), ""), nil + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + + dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger( mapToWireGuardLogLevel(loggerLevel), "wireguard: ", )) - endpoint, err := resolveDomain(wgData.Endpoint) + errs := make(chan error) + + uapiListener, err = uapiListen(interfaceName, fileUAPI) if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s -public_key=%s -allowed_ip=%s/32 -endpoint=%s -persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) + go func() { + for { + conn, err := uapiListener.Accept() + if err != nil { + errs <- err + return + } + go dev.IpcHandle(conn) + } + }() - err = dev.IpcSet(config) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v", err) - } + logger.Info("UAPI listener started") // Bring up the device err = dev.Up() @@ -407,206 +553,354 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( logger.Error("Failed to bring up WireGuard device: %v", err) } - logger.Info("WireGuard device created. Lets ping the server now...") - // Ping to bring the tunnel up on the server side quickly - // ping(tnet, wgData.ServerIP) - err = pingWithRetry(tnet, wgData.ServerIP) + // configure the interface + err = ConfigureInterface(realInterfaceName, wgData) if err != nil { - // Handle complete failure after all retries - logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + logger.Error("Failed to configure interface: %v", err) } - if !connected { - logger.Info("Starting ping check") - startPingCheck(tnet, wgData.ServerIP, pingStopChan) + peerMonitor = peermonitor.NewPeerMonitor( + func(siteID int, connected bool, rtt time.Duration) { + if httpServer != nil { + httpServer.UpdatePeerStatus(siteID, connected, rtt) + } + if connected { + logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) + } else { + logger.Warn("Peer %d is disconnected", siteID) + } + }, + fixKey(privateKey.String()), + olm, + dev, + doHolepunch, + ) + + // loop over the sites and call ConfigurePeer for each one + for _, site := range wgData.Sites { + if httpServer != nil { + httpServer.UpdatePeerStatus(site.SiteId, false, 0) + } + err = ConfigurePeer(dev, site, privateKey, endpoint) + if err != nil { + logger.Error("Failed to configure peer: %v", err) + return + } + + err = addRouteForServerIP(site.ServerIP, interfaceName) + if err != nil { + logger.Error("Failed to add route for peer: %v", err) + return + } + + // Add routes for remote subnets + if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } + + logger.Info("Configured peer %s", site.PublicKey) } - // Create proxy manager - pm = proxy.NewProxyManager(tnet) + peerMonitor.Start() connected = true - // add the targets if there are any - if len(wgData.Targets.TCP) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP}) - } + logger.Info("WireGuard device created.") + }) - if len(wgData.Targets.UDP) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) - } + olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { + logger.Debug("Received update-peer message: %v", msg.Data) - err = pm.Start() + jsonData, err := json.Marshal(msg.Data) if err != nil { - logger.Error("Failed to start proxy manager: %v", err) + logger.Error("Error marshaling data: %v", err) + return + } + + var updateData UpdatePeerData + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) + return + } + + // Convert to SiteConfig + siteConfig := SiteConfig{ + SiteId: updateData.SiteId, + Endpoint: updateData.Endpoint, + PublicKey: updateData.PublicKey, + ServerIP: updateData.ServerIP, + ServerPort: updateData.ServerPort, + RemoteSubnets: updateData.RemoteSubnets, + } + + // Update the peer in WireGuard + if dev != nil { + // Find the existing peer to get old RemoteSubnets + var oldRemoteSubnets string + for _, site := range wgData.Sites { + if site.SiteId == updateData.SiteId { + oldRemoteSubnets = site.RemoteSubnets + break + } + } + + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to update peer: %v", err) + // Send error response if needed + return + } + + // Remove old remote subnet routes if they changed + if oldRemoteSubnets != siteConfig.RemoteSubnets { + if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { + logger.Error("Failed to remove old remote subnet routes: %v", err) + // Continue anyway to add new routes + } + + // Add new remote subnet routes + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add new remote subnet routes: %v", err) + return + } + } + + // Update successful + logger.Info("Successfully updated peer for site %d", updateData.SiteId) + // If this is part of a WgData structure, update it + for i, site := range wgData.Sites { + if site.SiteId == updateData.SiteId { + wgData.Sites[i] = siteConfig + break + } + } + } else { + logger.Error("WireGuard device not initialized") } }) - client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) + // Handler for adding a new peer + olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { + logger.Debug("Received add-peer message: %v", msg.Data) - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) + jsonData, err := json.Marshal(msg.Data) if err != nil { - logger.Info("Error parsing target data: %v", err) + logger.Error("Error marshaling data: %v", err) return } - if len(targetData.Targets) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData) + var addData AddPeerData + if err := json.Unmarshal(jsonData, &addData); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + // Convert to SiteConfig + siteConfig := SiteConfig{ + SiteId: addData.SiteId, + Endpoint: addData.Endpoint, + PublicKey: addData.PublicKey, + ServerIP: addData.ServerIP, + ServerPort: addData.ServerPort, + RemoteSubnets: addData.RemoteSubnets, + } + + // Add the peer to WireGuard + if dev != nil { + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + // Add route for the new peer + err = addRouteForServerIP(siteConfig.ServerIP, interfaceName) + if err != nil { + logger.Error("Failed to add route for new peer: %v", err) + return + } + + // Add routes for remote subnets + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } + + // Add successful + logger.Info("Successfully added peer for site %d", addData.SiteId) + + // Update WgData with the new peer + wgData.Sites = append(wgData.Sites, siteConfig) + } else { + logger.Error("WireGuard device not initialized") } }) - client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) + // Handler for removing a peer + olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { + logger.Debug("Received remove-peer message: %v", msg.Data) - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) + jsonData, err := json.Marshal(msg.Data) if err != nil { - logger.Info("Error parsing target data: %v", err) + logger.Error("Error marshaling data: %v", err) return } - if len(targetData.Targets) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData) + var removeData RemovePeerData + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + // Find the peer to remove + var peerToRemove *SiteConfig + var newSites []SiteConfig + + for _, site := range wgData.Sites { + if site.SiteId == removeData.SiteId { + peerToRemove = &site + } else { + newSites = append(newSites, site) + } + } + + if peerToRemove == nil { + logger.Error("Peer with site ID %d not found", removeData.SiteId) + return + } + + // Remove the peer from WireGuard + if dev != nil { + if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { + logger.Error("Failed to remove peer: %v", err) + // Send error response if needed + return + } + + // Remove route for the peer + err = removeRouteForServerIP(peerToRemove.ServerIP) + if err != nil { + logger.Error("Failed to remove route for peer: %v", err) + return + } + + // Remove routes for remote subnets + if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { + logger.Error("Failed to remove routes for remote subnets: %v", err) + return + } + + // Remove successful + logger.Info("Successfully removed peer for site %d", removeData.SiteId) + + // Update WgData to remove the peer + wgData.Sites = newSites + } else { + logger.Error("WireGuard device not initialized") } }) - client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) + olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) + jsonData, err := json.Marshal(msg.Data) if err != nil { - logger.Info("Error parsing target data: %v", err) + logger.Error("Error marshaling data: %v", err) return } - if len(targetData.Targets) > 0 { - updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData) + var removeData RelayPeerData + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return } + + primaryRelay, err := resolveDomain(removeData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + peerMonitor.HandleFailover(removeData.SiteId, primaryRelay) }) - client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData) - } + olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { + logger.Info("Received terminate message") + olm.Close() }) - client.OnConnect(func() error { + olm.OnConnect(func() error { + logger.Info("Websocket Connected") + + if httpServer != nil { + httpServer.SetConnectionStatus(true) + } + + if connected { + logger.Debug("Already connected, skipping registration") + return nil + } + publicKey := privateKey.PublicKey() - logger.Debug("Public key: %s", publicKey) - err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": fmt.Sprintf("%s", publicKey), - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } + logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) + + stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !doHolepunch, + }, 1*time.Second) + + go keepSendingPing(olm) logger.Info("Sent registration message") return nil }) + olm.OnTokenUpdate(func(token string) { + olmToken = token + }) + // Connect to the WebSocket server - if err := client.Connect(); err != nil { + if err := olm.Connect(); err != nil { logger.Fatal("Failed to connect to server: %v", err) } - defer client.Close() + defer olm.Close() - // Wait for interrupt signal + // Wait for interrupt signal or context cancellation sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh - // Cleanup - dev.Close() -} - -func parseTargetData(data interface{}) (TargetData, error) { - var targetData TargetData - jsonData, err := json.Marshal(data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return targetData, err + select { + case <-sigCh: + logger.Info("Received interrupt signal") + case <-ctx.Done(): + logger.Info("Context cancelled") } - if err := json.Unmarshal(jsonData, &targetData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return targetData, err - } - return targetData, nil -} - -func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - logger.Info("Invalid port: %s", parts[0]) - continue - } - - if action == "add" { - target := parts[1] + ":" + parts[2] - // Only remove the specific target if it exists - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - // Ignore "target not found" errors as this is expected for new targets - if !strings.Contains(err.Error(), "target not found") { - logger.Error("Failed to remove existing target: %v", err) - } - } - - // Add the new target - pm.AddTarget(proto, tunnelIP, port, target) - - } else if action == "remove" { - logger.Info("Removing target with port %d", port) - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - logger.Error("Failed to remove target: %v", err) - return err - } - } + select { + case <-stopHolepunch: + // Channel already closed, do nothing + default: + close(stopHolepunch) } - return nil + if stopRegister != nil { + stopRegister() + stopRegister = nil + } + + select { + case <-stopPing: + // Channel already closed + default: + close(stopPing) + } + + if uapiListener != nil { + uapiListener.Close() + } + if dev != nil { + dev.Close() + } + + logger.Info("runOlmMain() exiting") + fmt.Printf("runOlmMain() exiting\n") } diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go new file mode 100644 index 0000000..df90de2 --- /dev/null +++ b/peermonitor/peermonitor.go @@ -0,0 +1,324 @@ +package peermonitor + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/websocket" + "github.com/fosrl/olm/wgtester" + "golang.zx2c4.com/wireguard/device" +) + +// PeerMonitorCallback is the function type for connection status change callbacks +type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration) + +// WireGuardConfig holds the WireGuard configuration for a peer +type WireGuardConfig struct { + SiteID int + PublicKey string + ServerIP string + Endpoint string + PrimaryRelay string // The primary relay endpoint +} + +// PeerMonitor handles monitoring the connection status to multiple WireGuard peers +type PeerMonitor struct { + monitors map[int]*wgtester.Client + configs map[int]*WireGuardConfig + callback PeerMonitorCallback + mutex sync.Mutex + running bool + interval time.Duration + timeout time.Duration + maxAttempts int + privateKey string + wsClient *websocket.Client + device *device.Device + handleRelaySwitch bool // Whether to handle relay switching +} + +// NewPeerMonitor creates a new peer monitor with the given callback +func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor { + return &PeerMonitor{ + monitors: make(map[int]*wgtester.Client), + configs: make(map[int]*WireGuardConfig), + callback: callback, + interval: 1 * time.Second, // Default check interval + timeout: 2500 * time.Millisecond, + maxAttempts: 8, + privateKey: privateKey, + wsClient: wsClient, + device: device, + handleRelaySwitch: handleRelaySwitch, + } +} + +// SetInterval changes how frequently peers are checked +func (pm *PeerMonitor) SetInterval(interval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.interval = interval + + // Update interval for all existing monitors + for _, client := range pm.monitors { + client.SetPacketInterval(interval) + } +} + +// SetTimeout changes the timeout for waiting for responses +func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.timeout = timeout + + // Update timeout for all existing monitors + for _, client := range pm.monitors { + client.SetTimeout(timeout) + } +} + +// SetMaxAttempts changes the maximum number of attempts for TestConnection +func (pm *PeerMonitor) SetMaxAttempts(attempts int) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.maxAttempts = attempts + + // Update max attempts for all existing monitors + for _, client := range pm.monitors { + client.SetMaxAttempts(attempts) + } +} + +// AddPeer adds a new peer to monitor +func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + // Check if we're already monitoring this peer + if _, exists := pm.monitors[siteID]; exists { + // Update the endpoint instead of creating a new monitor + pm.removePeerUnlocked(siteID) + } + + client, err := wgtester.NewClient(endpoint) + if err != nil { + return err + } + + // Configure the client with our settings + client.SetPacketInterval(pm.interval) + client.SetTimeout(pm.timeout) + client.SetMaxAttempts(pm.maxAttempts) + + // Store the client and config + pm.monitors[siteID] = client + pm.configs[siteID] = wgConfig + + // If monitor is already running, start monitoring this peer + if pm.running { + siteIDCopy := siteID // Create a copy for the closure + err = client.StartMonitor(func(status wgtester.ConnectionStatus) { + pm.handleConnectionStatusChange(siteIDCopy, status) + }) + } + + return err +} + +// removePeerUnlocked stops monitoring a peer and removes it from the monitor +// This function assumes the mutex is already held by the caller +func (pm *PeerMonitor) removePeerUnlocked(siteID int) { + client, exists := pm.monitors[siteID] + if !exists { + return + } + + client.StopMonitor() + client.Close() + delete(pm.monitors, siteID) + delete(pm.configs, siteID) +} + +// RemovePeer stops monitoring a peer and removes it from the monitor +func (pm *PeerMonitor) RemovePeer(siteID int) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.removePeerUnlocked(siteID) +} + +// Start begins monitoring all peers +func (pm *PeerMonitor) Start() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if pm.running { + return // Already running + } + + pm.running = true + + // Start monitoring all peers + for siteID, client := range pm.monitors { + siteIDCopy := siteID // Create a copy for the closure + err := client.StartMonitor(func(status wgtester.ConnectionStatus) { + pm.handleConnectionStatusChange(siteIDCopy, status) + }) + if err != nil { + logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err) + continue + } + logger.Info("Started monitoring peer %d\n", siteID) + } +} + +// handleConnectionStatusChange is called when a peer's connection status changes +func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester.ConnectionStatus) { + // Call the user-provided callback first + if pm.callback != nil { + pm.callback(siteID, status.Connected, status.RTT) + } + + // If disconnected, handle failover + if !status.Connected { + // Send relay message to the server + if pm.wsClient != nil { + pm.sendRelay(siteID) + } + } +} + +// handleFailover handles failover to the relay server when a peer is disconnected +func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) { + pm.mutex.Lock() + config, exists := pm.configs[siteID] + pm.mutex.Unlock() + + if !exists { + return + } + + // Configure WireGuard to use the relay + wgConfig := fmt.Sprintf(`private_key=%s +public_key=%s +allowed_ip=%s/32 +endpoint=%s:21820 +persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, relayEndpoint) + + err := pm.device.IpcSet(wgConfig) + if err != nil { + logger.Error("Failed to configure WireGuard device: %v\n", err) + return + } + + logger.Info("Adjusted peer %d to point to relay!\n", siteID) +} + +// sendRelay sends a relay message to the server +func (pm *PeerMonitor) sendRelay(siteID int) error { + if !pm.handleRelaySwitch { + return nil + } + + if pm.wsClient == nil { + return fmt.Errorf("websocket client is nil") + } + + err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ + "siteId": siteID, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + logger.Info("Sent relay message") + return nil +} + +// Stop stops monitoring all peers +func (pm *PeerMonitor) Stop() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if !pm.running { + return + } + + pm.running = false + + // Stop all monitors + for _, client := range pm.monitors { + client.StopMonitor() + } +} + +// Close stops monitoring and cleans up resources +func (pm *PeerMonitor) Close() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + // Stop and close all clients + for siteID, client := range pm.monitors { + client.StopMonitor() + client.Close() + delete(pm.monitors, siteID) + } + + pm.running = false +} + +// TestPeer tests connectivity to a specific peer +func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { + pm.mutex.Lock() + client, exists := pm.monitors[siteID] + pm.mutex.Unlock() + + if !exists { + return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) + } + + ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) + defer cancel() + + connected, rtt := client.TestConnection(ctx) + return connected, rtt, nil +} + +// TestAllPeers tests connectivity to all peers +func (pm *PeerMonitor) TestAllPeers() map[int]struct { + Connected bool + RTT time.Duration +} { + pm.mutex.Lock() + peers := make(map[int]*wgtester.Client, len(pm.monitors)) + for siteID, client := range pm.monitors { + peers[siteID] = client + } + pm.mutex.Unlock() + + results := make(map[int]struct { + Connected bool + RTT time.Duration + }) + for siteID, client := range peers { + ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) + connected, rtt := client.TestConnection(ctx) + cancel() + + results[siteID] = struct { + Connected bool + RTT time.Duration + }{ + Connected: connected, + RTT: rtt, + } + } + + return results +} diff --git a/proxy/manager.go b/proxy/manager.go deleted file mode 100644 index b6c521b..0000000 --- a/proxy/manager.go +++ /dev/null @@ -1,352 +0,0 @@ -package proxy - -import ( - "fmt" - "io" - "net" - "strings" - "sync" - "time" - - "github.com/fosrl/newt/logger" - "golang.zx2c4.com/wireguard/tun/netstack" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" -) - -// Target represents a proxy target with its address and port -type Target struct { - Address string - Port int -} - -// ProxyManager handles the creation and management of proxy connections -type ProxyManager struct { - tnet *netstack.Net - tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress - udpTargets map[string]map[int]string - listeners []*gonet.TCPListener - udpConns []*gonet.UDPConn - running bool - mutex sync.RWMutex -} - -// NewProxyManager creates a new proxy manager instance -func NewProxyManager(tnet *netstack.Net) *ProxyManager { - return &ProxyManager{ - tnet: tnet, - tcpTargets: make(map[string]map[int]string), - udpTargets: make(map[string]map[int]string), - listeners: make([]*gonet.TCPListener, 0), - udpConns: make([]*gonet.UDPConn, 0), - } -} - -// AddTarget adds a new target for proxying -func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - switch proto { - case "tcp": - if pm.tcpTargets[listenIP] == nil { - pm.tcpTargets[listenIP] = make(map[int]string) - } - pm.tcpTargets[listenIP][port] = targetAddr - case "udp": - if pm.udpTargets[listenIP] == nil { - pm.udpTargets[listenIP] = make(map[int]string) - } - pm.udpTargets[listenIP][port] = targetAddr - default: - return fmt.Errorf("unsupported protocol: %s", proto) - } - - if pm.running { - return pm.startTarget(proto, listenIP, port, targetAddr) - } else { - logger.Info("Not adding target because not running") - } - return nil -} - -func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - switch proto { - case "tcp": - if targets, ok := pm.tcpTargets[listenIP]; ok { - delete(targets, port) - // Remove and close the corresponding TCP listener - for i, listener := range pm.listeners { - if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port { - listener.Close() - time.Sleep(50 * time.Millisecond) - // Remove from slice - pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...) - break - } - } - } else { - return fmt.Errorf("target not found: %s:%d", listenIP, port) - } - case "udp": - if targets, ok := pm.udpTargets[listenIP]; ok { - delete(targets, port) - // Remove and close the corresponding UDP connection - for i, conn := range pm.udpConns { - if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port { - conn.Close() - time.Sleep(50 * time.Millisecond) - // Remove from slice - pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) - break - } - } - } else { - return fmt.Errorf("target not found: %s:%d", listenIP, port) - } - default: - return fmt.Errorf("unsupported protocol: %s", proto) - } - return nil -} - -// Start begins listening for all configured proxy targets -func (pm *ProxyManager) Start() error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - if pm.running { - return nil - } - - // Start TCP targets - for listenIP, targets := range pm.tcpTargets { - for port, targetAddr := range targets { - if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil { - return fmt.Errorf("failed to start TCP target: %v", err) - } - } - } - - // Start UDP targets - for listenIP, targets := range pm.udpTargets { - for port, targetAddr := range targets { - if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil { - return fmt.Errorf("failed to start UDP target: %v", err) - } - } - } - - pm.running = true - return nil -} - -func (pm *ProxyManager) Stop() error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - if !pm.running { - return nil - } - - // Set running to false first to signal handlers to stop - pm.running = false - - // Close TCP listeners - for i := len(pm.listeners) - 1; i >= 0; i-- { - listener := pm.listeners[i] - if err := listener.Close(); err != nil { - logger.Error("Error closing TCP listener: %v", err) - } - // Remove from slice - pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...) - } - - // Close UDP connections - for i := len(pm.udpConns) - 1; i >= 0; i-- { - conn := pm.udpConns[i] - if err := conn.Close(); err != nil { - logger.Error("Error closing UDP connection: %v", err) - } - // Remove from slice - pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) - } - - // Clear the target maps - for k := range pm.tcpTargets { - delete(pm.tcpTargets, k) - } - for k := range pm.udpTargets { - delete(pm.udpTargets, k) - } - - // Give active connections a chance to close gracefully - time.Sleep(100 * time.Millisecond) - - return nil -} - -func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error { - switch proto { - case "tcp": - listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port}) - if err != nil { - return fmt.Errorf("failed to create TCP listener: %v", err) - } - - pm.listeners = append(pm.listeners, listener) - go pm.handleTCPProxy(listener, targetAddr) - - case "udp": - addr := &net.UDPAddr{Port: port} - conn, err := pm.tnet.ListenUDP(addr) - if err != nil { - return fmt.Errorf("failed to create UDP listener: %v", err) - } - - pm.udpConns = append(pm.udpConns, conn) - go pm.handleUDPProxy(conn, targetAddr) - - default: - return fmt.Errorf("unsupported protocol: %s", proto) - } - - logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr) - - return nil -} - -func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) { - for { - conn, err := listener.Accept() - if err != nil { - // Check if we're shutting down or the listener was closed - if !pm.running { - return - } - - // Check for specific network errors that indicate the listener is closed - if ne, ok := err.(net.Error); ok && !ne.Temporary() { - logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) - return - } - - logger.Error("Error accepting TCP connection: %v", err) - // Don't hammer the CPU if we hit a temporary error - time.Sleep(100 * time.Millisecond) - continue - } - - go func() { - target, err := net.Dial("tcp", targetAddr) - if err != nil { - logger.Error("Error connecting to target: %v", err) - conn.Close() - return - } - - // Create a WaitGroup to ensure both copy operations complete - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - io.Copy(target, conn) - target.Close() - }() - - go func() { - defer wg.Done() - io.Copy(conn, target) - conn.Close() - }() - - // Wait for both copies to complete - wg.Wait() - }() - } -} - -func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { - buffer := make([]byte, 65507) // Max UDP packet size - clientConns := make(map[string]*net.UDPConn) - var clientsMutex sync.RWMutex - - for { - n, remoteAddr, err := conn.ReadFrom(buffer) - if err != nil { - if !pm.running { - return - } - - // Check for connection closed conditions - if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { - logger.Info("UDP connection closed, stopping proxy handler") - - // Clean up existing client connections - clientsMutex.Lock() - for _, targetConn := range clientConns { - targetConn.Close() - } - clientConns = nil - clientsMutex.Unlock() - - return - } - - logger.Error("Error reading UDP packet: %v", err) - continue - } - - clientKey := remoteAddr.String() - clientsMutex.RLock() - targetConn, exists := clientConns[clientKey] - clientsMutex.RUnlock() - - if !exists { - targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) - if err != nil { - logger.Error("Error resolving target address: %v", err) - continue - } - - targetConn, err = net.DialUDP("udp", nil, targetUDPAddr) - if err != nil { - logger.Error("Error connecting to target: %v", err) - continue - } - - clientsMutex.Lock() - clientConns[clientKey] = targetConn - clientsMutex.Unlock() - - go func() { - buffer := make([]byte, 65507) - for { - n, _, err := targetConn.ReadFromUDP(buffer) - if err != nil { - logger.Error("Error reading from target: %v", err) - return - } - - _, err = conn.WriteTo(buffer[:n], remoteAddr) - if err != nil { - logger.Error("Error writing to client: %v", err) - return - } - } - }() - } - - _, err = targetConn.Write(buffer[:n]) - if err != nil { - logger.Error("Error writing to target: %v", err) - targetConn.Close() - clientsMutex.Lock() - delete(clientConns, clientKey) - clientsMutex.Unlock() - } - } -} diff --git a/public/screenshots/preview.png b/public/screenshots/preview.png deleted file mode 100644 index c6a8cd8..0000000 Binary files a/public/screenshots/preview.png and /dev/null differ diff --git a/service_unix.go b/service_unix.go new file mode 100644 index 0000000..c9f5fbf --- /dev/null +++ b/service_unix.go @@ -0,0 +1,50 @@ +//go:build !windows + +package main + +import ( + "fmt" +) + +// Service management functions are not available on non-Windows platforms +func installService() error { + return fmt.Errorf("service management is only available on Windows") +} + +func removeService() error { + return fmt.Errorf("service management is only available on Windows") +} + +func startService(args []string) error { + _ = args // unused on Unix platforms + return fmt.Errorf("service management is only available on Windows") +} + +func stopService() error { + return fmt.Errorf("service management is only available on Windows") +} + +func getServiceStatus() (string, error) { + return "", fmt.Errorf("service management is only available on Windows") +} + +func debugService(args []string) error { + _ = args // unused on Unix platforms + return fmt.Errorf("debug service is only available on Windows") +} + +func isWindowsService() bool { + return false +} + +func runService(name string, isDebug bool, args []string) { + // No-op on non-Windows platforms +} + +func setupWindowsEventLog() { + // No-op on non-Windows platforms +} + +func watchLogFile(end bool) error { + return fmt.Errorf("watching log file is only available on Windows") +} diff --git a/service_windows.go b/service_windows.go new file mode 100644 index 0000000..f4dd7ff --- /dev/null +++ b/service_windows.go @@ -0,0 +1,537 @@ +//go:build windows + +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/signal" + "path/filepath" + "syscall" + "time" + + "github.com/fosrl/newt/logger" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" + "golang.org/x/sys/windows/svc/eventlog" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + serviceName = "OlmWireguardService" + serviceDisplayName = "Olm WireGuard VPN Service" + serviceDescription = "Olm WireGuard VPN client service for secure network connectivity" +) + +// Global variable to store service arguments +var serviceArgs []string + +// getServiceArgsPath returns the path where service arguments are stored +func getServiceArgsPath() string { + logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm") + return filepath.Join(logDir, "service_args.json") +} + +// saveServiceArgs saves the service arguments to a file +func saveServiceArgs(args []string) error { + logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm") + err := os.MkdirAll(logDir, 0755) + if err != nil { + return fmt.Errorf("failed to create config directory: %v", err) + } + + argsPath := getServiceArgsPath() + data, err := json.Marshal(args) + if err != nil { + return fmt.Errorf("failed to marshal service args: %v", err) + } + + err = os.WriteFile(argsPath, data, 0644) + if err != nil { + return fmt.Errorf("failed to write service args: %v", err) + } + + return nil +} + +// loadServiceArgs loads the service arguments from a file +func loadServiceArgs() ([]string, error) { + argsPath := getServiceArgsPath() + data, err := os.ReadFile(argsPath) + if err != nil { + if os.IsNotExist(err) { + return []string{}, nil // Return empty args if file doesn't exist + } + return nil, fmt.Errorf("failed to read service args: %v", err) + } + + // delete the file after reading + err = os.Remove(argsPath) + if err != nil { + return nil, fmt.Errorf("failed to delete service args file: %v", err) + } + + var args []string + err = json.Unmarshal(data, &args) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal service args: %v", err) + } + + return args, nil +} + +type olmService struct { + elog debug.Log + ctx context.Context + stop context.CancelFunc + args []string +} + +func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { + const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown + changes <- svc.Status{State: svc.StartPending} + + s.elog.Info(1, "Service Execute called, starting main logic") + + // Load saved service arguments + savedArgs, err := loadServiceArgs() + if err != nil { + s.elog.Error(1, fmt.Sprintf("Failed to load service args: %v", err)) + // Continue with empty args if loading fails + savedArgs = []string{} + } + s.args = savedArgs + + // Start the main olm functionality + olmDone := make(chan struct{}) + go func() { + s.runOlm() + close(olmDone) + }() + + changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + s.elog.Info(1, "Service status set to Running") + + for { + select { + case c := <-r: + switch c.Cmd { + case svc.Interrogate: + changes <- c.CurrentStatus + case svc.Stop, svc.Shutdown: + s.elog.Info(1, "Service stopping") + changes <- svc.Status{State: svc.StopPending} + if s.stop != nil { + s.stop() + } + // Wait for main logic to finish or timeout + select { + case <-olmDone: + s.elog.Info(1, "Main logic finished gracefully") + case <-time.After(10 * time.Second): + s.elog.Info(1, "Timeout waiting for main logic to finish") + } + return false, 0 + default: + s.elog.Error(1, fmt.Sprintf("Unexpected control request #%d", c)) + } + case <-olmDone: + s.elog.Info(1, "Main olm logic completed, stopping service") + changes <- svc.Status{State: svc.StopPending} + return false, 0 + } + } +} + +func (s *olmService) runOlm() { + // Create a context that can be cancelled when the service stops + s.ctx, s.stop = context.WithCancel(context.Background()) + + // Setup logging for service mode + s.elog.Info(1, "Starting Olm main logic") + + // Run the main olm logic and wait for it to complete + done := make(chan struct{}) + go func() { + defer func() { + if r := recover(); r != nil { + s.elog.Error(1, fmt.Sprintf("Olm panic: %v", r)) + } + close(done) + }() + + // Call the main olm function with stored arguments + runOlmMainWithArgs(s.ctx, s.args) + }() + + // Wait for either context cancellation or main logic completion + select { + case <-s.ctx.Done(): + s.elog.Info(1, "Olm service context cancelled") + case <-done: + s.elog.Info(1, "Olm main logic completed") + } +} + +func runService(name string, isDebug bool, args []string) { + var err error + var elog debug.Log + + if isDebug { + elog = debug.New(name) + fmt.Printf("Starting %s service in debug mode\n", name) + } else { + elog, err = eventlog.Open(name) + if err != nil { + fmt.Printf("Failed to open event log: %v\n", err) + return + } + } + defer elog.Close() + + elog.Info(1, fmt.Sprintf("Starting %s service", name)) + run := svc.Run + if isDebug { + run = debug.Run + } + + service := &olmService{elog: elog, args: args} + err = run(name, service) + if err != nil { + elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err)) + if isDebug { + fmt.Printf("Service failed: %v\n", err) + } + return + } + elog.Info(1, fmt.Sprintf("%s service stopped", name)) + if isDebug { + fmt.Printf("%s service stopped\n", name) + } +} + +func installService() error { + exepath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %v", err) + } + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %v", err) + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err == nil { + s.Close() + return fmt.Errorf("service %s already exists", serviceName) + } + + config := mgr.Config{ + ServiceType: 0x10, // SERVICE_WIN32_OWN_PROCESS + StartType: mgr.StartManual, + ErrorControl: mgr.ErrorNormal, + DisplayName: serviceDisplayName, + Description: serviceDescription, + BinaryPathName: exepath, + } + + s, err = m.CreateService(serviceName, exepath, config) + if err != nil { + return fmt.Errorf("failed to create service: %v", err) + } + defer s.Close() + + err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info) + if err != nil { + s.Delete() + return fmt.Errorf("failed to install event log: %v", err) + } + + return nil +} + +func removeService() error { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %v", err) + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("service %s is not installed", serviceName) + } + defer s.Close() + + // Stop the service if it's running + status, err := s.Query() + if err != nil { + return fmt.Errorf("failed to query service status: %v", err) + } + + if status.State != svc.Stopped { + _, err = s.Control(svc.Stop) + if err != nil { + return fmt.Errorf("failed to stop service: %v", err) + } + + // Wait for service to stop + timeout := time.Now().Add(30 * time.Second) + for status.State != svc.Stopped { + if timeout.Before(time.Now()) { + return fmt.Errorf("timeout waiting for service to stop") + } + time.Sleep(300 * time.Millisecond) + status, err = s.Query() + if err != nil { + return fmt.Errorf("failed to query service status: %v", err) + } + } + } + + err = s.Delete() + if err != nil { + return fmt.Errorf("failed to delete service: %v", err) + } + + err = eventlog.Remove(serviceName) + if err != nil { + return fmt.Errorf("failed to remove event log: %v", err) + } + + return nil +} + +func startService(args []string) error { + // Save the service arguments before starting + if len(args) > 0 { + err := saveServiceArgs(args) + if err != nil { + return fmt.Errorf("failed to save service args: %v", err) + } + } + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %v", err) + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("service %s is not installed", serviceName) + } + defer s.Close() + + err = s.Start() + if err != nil { + return fmt.Errorf("failed to start service: %v", err) + } + + return nil +} + +func stopService() error { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %v", err) + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("service %s is not installed", serviceName) + } + defer s.Close() + + status, err := s.Control(svc.Stop) + if err != nil { + return fmt.Errorf("failed to stop service: %v", err) + } + + timeout := time.Now().Add(30 * time.Second) + for status.State != svc.Stopped { + if timeout.Before(time.Now()) { + return fmt.Errorf("timeout waiting for service to stop") + } + time.Sleep(300 * time.Millisecond) + status, err = s.Query() + if err != nil { + return fmt.Errorf("failed to query service status: %v", err) + } + } + + return nil +} + +func debugService(args []string) error { + // Save the service arguments before starting + if len(args) > 0 { + err := saveServiceArgs(args) + if err != nil { + return fmt.Errorf("failed to save service args: %v", err) + } + } + + // fmt.Printf("Starting service in debug mode...\n") + + // Start the service + err := startService([]string{}) // Pass empty args since we already saved them + if err != nil { + return fmt.Errorf("failed to start service: %v", err) + } + + // fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n") + // fmt.Printf("================================================================================\n") + + // Watch the log file + return watchLogFile(true) +} + +func watchLogFile(end bool) error { + logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs") + logPath := filepath.Join(logDir, "olm.log") + + // Ensure the log directory exists + err := os.MkdirAll(logDir, 0755) + if err != nil { + return fmt.Errorf("failed to create log directory: %v", err) + } + + // Wait for the log file to be created if it doesn't exist + var file *os.File + for i := 0; i < 30; i++ { // Wait up to 15 seconds + file, err = os.Open(logPath) + if err == nil { + break + } + if i == 0 { + fmt.Printf("Waiting for log file to be created...\n") + } + time.Sleep(500 * time.Millisecond) + } + + if err != nil { + return fmt.Errorf("failed to open log file after waiting: %v", err) + } + defer file.Close() + + // Seek to the end of the file to only show new logs + _, err = file.Seek(0, 2) + if err != nil { + return fmt.Errorf("failed to seek to end of file: %v", err) + } + + // Set up signal handling for graceful exit + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + // Create a ticker to check for new content + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + buffer := make([]byte, 4096) + + for { + select { + case <-sigCh: + fmt.Printf("\n\nStopping log watch...\n") + // stop the service if needed + if end { + if err := stopService(); err != nil { + fmt.Printf("Failed to stop service: %v\n", err) + } + } + fmt.Printf("Log watch stopped.\n") + return nil + case <-ticker.C: + // Read new content + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + // Try to reopen the file in case it was recreated + file.Close() + file, err = os.Open(logPath) + if err != nil { + return fmt.Errorf("error reopening log file: %v", err) + } + continue + } + + if n > 0 { + // Print the new content + fmt.Print(string(buffer[:n])) + } + } + } +} + +func getServiceStatus() (string, error) { + m, err := mgr.Connect() + if err != nil { + return "", fmt.Errorf("failed to connect to service manager: %v", err) + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return "Not Installed", nil + } + defer s.Close() + + status, err := s.Query() + if err != nil { + return "", fmt.Errorf("failed to query service status: %v", err) + } + + switch status.State { + case svc.Stopped: + return "Stopped", nil + case svc.StartPending: + return "Starting", nil + case svc.StopPending: + return "Stopping", nil + case svc.Running: + return "Running", nil + case svc.ContinuePending: + return "Continue Pending", nil + case svc.PausePending: + return "Pause Pending", nil + case svc.Paused: + return "Paused", nil + default: + return "Unknown", nil + } +} + +func isWindowsService() bool { + isWindowsService, err := svc.IsWindowsService() + return err == nil && isWindowsService +} + +func setupWindowsEventLog() { + // Create log directory if it doesn't exist + logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs") + err := os.MkdirAll(logDir, 0755) + if err != nil { + fmt.Printf("Failed to create log directory: %v\n", err) + return + } + + logFile := filepath.Join(logDir, "olm.log") + file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + fmt.Printf("Failed to open log file: %v\n", err) + return + } + + // Set the custom logger output + logger.GetLogger().SetOutput(file) + + log.Printf("Olm service logging initialized - log file: %s", logFile) +} diff --git a/unix.go b/unix.go new file mode 100644 index 0000000..3a9c09e --- /dev/null +++ b/unix.go @@ -0,0 +1,35 @@ +//go:build !windows + +package main + +import ( + "net" + "os" + "strconv" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { + fd, err := strconv.ParseUint(tunFdStr, 10, 32) + if err != nil { + return nil, err + } + + err = unix.SetNonblock(int(fd), true) + if err != nil { + return nil, err + } + + file := os.NewFile(uintptr(fd), "") + return tun.CreateTUNFromFile(file, mtuInt) +} +func uapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) +} + +func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + return ipc.UAPIListen(interfaceName, fileUAPI) +} diff --git a/websocket/client.go b/websocket/client.go deleted file mode 100644 index 8a7d3f9..0000000 --- a/websocket/client.go +++ /dev/null @@ -1,351 +0,0 @@ -package websocket - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/fosrl/newt/logger" - - "github.com/gorilla/websocket" -) - -type Client struct { - conn *websocket.Conn - config *Config - baseURL string - handlers map[string]MessageHandler - done chan struct{} - handlersMux sync.RWMutex - - reconnectInterval time.Duration - isConnected bool - reconnectMux sync.RWMutex - - onConnect func() error -} - -type ClientOption func(*Client) - -type MessageHandler func(message WSMessage) - -// WithBaseURL sets the base URL for the client -func WithBaseURL(url string) ClientOption { - return func(c *Client) { - c.baseURL = url - } -} - -func (c *Client) OnConnect(callback func() error) { - c.onConnect = callback -} - -// NewClient creates a new Newt client -func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { - config := &Config{ - NewtID: newtID, - Secret: secret, - Endpoint: endpoint, - } - - client := &Client{ - config: config, - baseURL: endpoint, // default value - handlers: make(map[string]MessageHandler), - done: make(chan struct{}), - reconnectInterval: 10 * time.Second, - isConnected: false, - } - - // Apply options before loading config - for _, opt := range opts { - opt(client) - } - - // Load existing config if available - if err := client.loadConfig(); err != nil { - return nil, fmt.Errorf("failed to load config: %w", err) - } - - return client, nil -} - -// Connect establishes the WebSocket connection -func (c *Client) Connect() error { - go c.connectWithRetry() - return nil -} - -// Close closes the WebSocket connection -func (c *Client) Close() error { - close(c.done) - if c.conn != nil { - return c.conn.Close() - } - - // stop the ping monitor - c.setConnected(false) - - return nil -} - -// SendMessage sends a message through the WebSocket connection -func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { - return fmt.Errorf("not connected") - } - - msg := WSMessage{ - Type: messageType, - Data: data, - } - - return c.conn.WriteJSON(msg) -} - -// RegisterHandler registers a handler for a specific message type -func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { - c.handlersMux.Lock() - defer c.handlersMux.Unlock() - c.handlers[messageType] = handler -} - -// readPump pumps messages from the WebSocket connection -func (c *Client) readPump() { - defer c.conn.Close() - - for { - select { - case <-c.done: - return - default: - var msg WSMessage - err := c.conn.ReadJSON(&msg) - if err != nil { - return - } - - c.handlersMux.RLock() - if handler, ok := c.handlers[msg.Type]; ok { - handler(msg) - } - c.handlersMux.RUnlock() - } - } -} - -func (c *Client) getToken() (string, error) { - // Parse the base URL to ensure we have the correct hostname - baseURL, err := url.Parse(c.baseURL) - if err != nil { - return "", fmt.Errorf("failed to parse base URL: %w", err) - } - - // Ensure we have the base URL without trailing slashes - baseEndpoint := strings.TrimRight(baseURL.String(), "/") - - // If we already have a token, try to use it - if c.config.Token != "" { - tokenCheckData := map[string]interface{}{ - "newtId": c.config.NewtID, - "secret": c.config.Secret, - "token": c.config.Token, - } - jsonData, err := json.Marshal(tokenCheckData) - if err != nil { - return "", fmt.Errorf("failed to marshal token check data: %w", err) - } - - // Create a new request - req, err := http.NewRequest( - "POST", - baseEndpoint+"/api/v1/auth/newt/get-token", - bytes.NewBuffer(jsonData), - ) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-CSRF-Token", "x-csrf-protection") - - // Make the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("failed to check token validity: %w", err) - } - defer resp.Body.Close() - - var tokenResp TokenResponse - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", fmt.Errorf("failed to decode token check response: %w", err) - } - - // If token is still valid, return it - if tokenResp.Success && tokenResp.Message == "Token session already valid" { - return c.config.Token, nil - } - } - - // Get a new token - tokenData := map[string]interface{}{ - "newtId": c.config.NewtID, - "secret": c.config.Secret, - } - jsonData, err := json.Marshal(tokenData) - if err != nil { - return "", fmt.Errorf("failed to marshal token request data: %w", err) - } - - // Create a new request - req, err := http.NewRequest( - "POST", - baseEndpoint+"/api/v1/auth/newt/get-token", - bytes.NewBuffer(jsonData), - ) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-CSRF-Token", "x-csrf-protection") - - // Make the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("failed to request new token: %w", err) - } - defer resp.Body.Close() - - var tokenResp TokenResponse - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", fmt.Errorf("failed to decode token response: %w", err) - } - - if !tokenResp.Success { - return "", fmt.Errorf("failed to get token: %s", tokenResp.Message) - } - - if tokenResp.Data.Token == "" { - return "", fmt.Errorf("received empty token from server") - } - - return tokenResp.Data.Token, nil -} - -func (c *Client) connectWithRetry() { - for { - select { - case <-c.done: - return - default: - err := c.establishConnection() - if err != nil { - logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) - time.Sleep(c.reconnectInterval) - continue - } - return - } - } -} - -func (c *Client) establishConnection() error { - // Get token for authentication - token, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - - // Parse the base URL to determine protocol and hostname - baseURL, err := url.Parse(c.baseURL) - if err != nil { - return fmt.Errorf("failed to parse base URL: %w", err) - } - - // Determine WebSocket protocol based on HTTP protocol - wsProtocol := "wss" - if baseURL.Scheme == "http" { - wsProtocol = "ws" - } - - // Create WebSocket URL - wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host) - u, err := url.Parse(wsURL) - if err != nil { - return fmt.Errorf("failed to parse WebSocket URL: %w", err) - } - - // Add token to query parameters - q := u.Query() - q.Set("token", token) - u.RawQuery = q.Encode() - - // Connect to WebSocket - conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) - if err != nil { - return fmt.Errorf("failed to connect to WebSocket: %w", err) - } - - c.conn = conn - c.setConnected(true) - - // Start the ping monitor - go c.pingMonitor() - // Start the read pump - go c.readPump() - - if c.onConnect != nil { - err := c.saveConfig() - if err != nil { - logger.Error("Failed to save config: %v", err) - } - if err := c.onConnect(); err != nil { - logger.Error("OnConnect callback failed: %v", err) - } - } - - return nil -} - -func (c *Client) pingMonitor() { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-c.done: - return - case <-ticker.C: - if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { - logger.Error("Ping failed: %v", err) - c.reconnect() - return - } - } - } -} - -func (c *Client) reconnect() { - c.setConnected(false) - if c.conn != nil { - c.conn.Close() - } - - go c.connectWithRetry() -} - -func (c *Client) setConnected(status bool) { - c.reconnectMux.Lock() - defer c.reconnectMux.Unlock() - c.isConnected = status -} diff --git a/websocket/config.go b/websocket/config.go deleted file mode 100644 index 794ff1e..0000000 --- a/websocket/config.go +++ /dev/null @@ -1,72 +0,0 @@ -package websocket - -import ( - "encoding/json" - "log" - "os" - "path/filepath" - "runtime" -) - -func getConfigPath() string { - var configDir string - switch runtime.GOOS { - case "darwin": - configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "newt-client") - case "windows": - configDir = filepath.Join(os.Getenv("APPDATA"), "newt-client") - default: // linux and others - configDir = filepath.Join(os.Getenv("HOME"), ".config", "newt-client") - } - - if err := os.MkdirAll(configDir, 0755); err != nil { - log.Printf("Failed to create config directory: %v", err) - } - - return filepath.Join(configDir, "config.json") -} - -func (c *Client) loadConfig() error { - if c.config.NewtID != "" && c.config.Secret != "" && c.config.Endpoint != "" { - return nil - } - - configPath := getConfigPath() - data, err := os.ReadFile(configPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return err - } - - var config Config - if err := json.Unmarshal(data, &config); err != nil { - return err - } - - if c.config.NewtID == "" { - c.config.NewtID = config.NewtID - } - if c.config.Token == "" { - c.config.Token = config.Token - } - if c.config.Secret == "" { - c.config.Secret = config.Secret - } - if c.config.Endpoint == "" { - c.config.Endpoint = config.Endpoint - c.baseURL = config.Endpoint - } - - return nil -} - -func (c *Client) saveConfig() error { - configPath := getConfigPath() - data, err := json.MarshalIndent(c.config, "", " ") - if err != nil { - return err - } - return os.WriteFile(configPath, data, 0644) -} diff --git a/websocket/types.go b/websocket/types.go deleted file mode 100644 index 084465a..0000000 --- a/websocket/types.go +++ /dev/null @@ -1,21 +0,0 @@ -package websocket - -type Config struct { - NewtID string `json:"newtId"` - Secret string `json:"secret"` - Token string `json:"token"` - Endpoint string `json:"endpoint"` -} - -type TokenResponse struct { - Data struct { - Token string `json:"token"` - } `json:"data"` - Success bool `json:"success"` - Message string `json:"message"` -} - -type WSMessage struct { - Type string `json:"type"` - Data interface{} `json:"data"` -} diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go new file mode 100644 index 0000000..28ffdba --- /dev/null +++ b/wgtester/wgtester.go @@ -0,0 +1,260 @@ +package wgtester + +import ( + "context" + "encoding/binary" + "net" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +const ( + // Magic bytes to identify our packets + magicHeader uint32 = 0xDEADBEEF + // Request packet type + packetTypeRequest uint8 = 1 + // Response packet type + packetTypeResponse uint8 = 2 + // Packet format: + // - 4 bytes: magic header (0xDEADBEEF) + // - 1 byte: packet type (1 = request, 2 = response) + // - 8 bytes: timestamp (for round-trip timing) + packetSize = 13 +) + +// Client handles checking connectivity to a server +type Client struct { + conn *net.UDPConn + serverAddr string + monitorRunning bool + monitorLock sync.Mutex + connLock sync.Mutex // Protects connection operations + shutdownCh chan struct{} + packetInterval time.Duration + timeout time.Duration + maxAttempts int +} + +// ConnectionStatus represents the current connection state +type ConnectionStatus struct { + Connected bool + RTT time.Duration +} + +// NewClient creates a new connection test client +func NewClient(serverAddr string) (*Client, error) { + return &Client{ + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + packetInterval: 2 * time.Second, + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + }, nil +} + +// SetPacketInterval changes how frequently packets are sent in monitor mode +func (c *Client) SetPacketInterval(interval time.Duration) { + c.packetInterval = interval +} + +// SetTimeout changes the timeout for waiting for responses +func (c *Client) SetTimeout(timeout time.Duration) { + c.timeout = timeout +} + +// SetMaxAttempts changes the maximum number of attempts for TestConnection +func (c *Client) SetMaxAttempts(attempts int) { + c.maxAttempts = attempts +} + +// Close cleans up client resources +func (c *Client) Close() { + c.StopMonitor() + + c.connLock.Lock() + defer c.connLock.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +// ensureConnection makes sure we have an active UDP connection +func (c *Client) ensureConnection() error { + c.connLock.Lock() + defer c.connLock.Unlock() + + if c.conn != nil { + return nil + } + + serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) + if err != nil { + return err + } + + c.conn, err = net.DialUDP("udp", nil, serverAddr) + if err != nil { + return err + } + + return nil +} + +// TestConnection checks if the connection to the server is working +// Returns true if connected, false otherwise +func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { + if err := c.ensureConnection(); err != nil { + logger.Warn("Failed to ensure connection: %v", err) + return false, 0 + } + + // Prepare packet buffer + packet := make([]byte, packetSize) + binary.BigEndian.PutUint32(packet[0:4], magicHeader) + packet[4] = packetTypeRequest + + // Send multiple attempts as specified + for attempt := 0; attempt < c.maxAttempts; attempt++ { + select { + case <-ctx.Done(): + return false, 0 + default: + // Add current timestamp to packet + timestamp := time.Now().UnixNano() + binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp)) + + // Lock the connection for the entire send/receive operation + c.connLock.Lock() + + // Check if connection is still valid after acquiring lock + if c.conn == nil { + c.connLock.Unlock() + return false, 0 + } + + logger.Debug("Attempting to send monitor packet to %s", c.serverAddr) + _, err := c.conn.Write(packet) + if err != nil { + c.connLock.Unlock() + logger.Info("Error sending packet: %v", err) + continue + } + logger.Debug("Successfully sent monitor packet") + + // Set read deadline + c.conn.SetReadDeadline(time.Now().Add(c.timeout)) + + // Wait for response + responseBuffer := make([]byte, packetSize) + n, err := c.conn.Read(responseBuffer) + c.connLock.Unlock() + + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Timeout, try next attempt + time.Sleep(100 * time.Millisecond) // Brief pause between attempts + continue + } + logger.Error("Error reading response: %v", err) + continue + } + + if n != packetSize { + continue // Malformed packet + } + + // Verify response + magic := binary.BigEndian.Uint32(responseBuffer[0:4]) + packetType := responseBuffer[4] + if magic != magicHeader || packetType != packetTypeResponse { + continue // Not our response + } + + // Extract the original timestamp and calculate RTT + sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13])) + rtt := time.Duration(time.Now().UnixNano() - sentTimestamp) + + return true, rtt + } + } + + return false, 0 +} + +// TestConnectionWithTimeout tries to test connection with a timeout +// Returns true if connected, false otherwise +func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return c.TestConnection(ctx) +} + +// MonitorCallback is the function type for connection status change callbacks +type MonitorCallback func(status ConnectionStatus) + +// StartMonitor begins monitoring the connection and calls the callback +// when the connection status changes +func (c *Client) StartMonitor(callback MonitorCallback) error { + c.monitorLock.Lock() + defer c.monitorLock.Unlock() + + if c.monitorRunning { + logger.Info("Monitor already running") + return nil // Already running + } + + if err := c.ensureConnection(); err != nil { + return err + } + + c.monitorRunning = true + c.shutdownCh = make(chan struct{}) + + go func() { + var lastConnected bool + firstRun := true + + ticker := time.NewTicker(c.packetInterval) + defer ticker.Stop() + + for { + select { + case <-c.shutdownCh: + return + case <-ticker.C: + ctx, cancel := context.WithTimeout(context.Background(), c.timeout) + connected, rtt := c.TestConnection(ctx) + cancel() + + // Callback if status changed or it's the first check + if connected != lastConnected || firstRun { + callback(ConnectionStatus{ + Connected: connected, + RTT: rtt, + }) + lastConnected = connected + firstRun = false + } + } + } + }() + + return nil +} + +// StopMonitor stops the connection monitoring +func (c *Client) StopMonitor() { + c.monitorLock.Lock() + defer c.monitorLock.Unlock() + + if !c.monitorRunning { + return + } + + close(c.shutdownCh) + c.monitorRunning = false +} diff --git a/windows.go b/windows.go new file mode 100644 index 0000000..032096b --- /dev/null +++ b/windows.go @@ -0,0 +1,25 @@ +//go:build windows + +package main + +import ( + "errors" + "net" + "os" + + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { + return nil, errors.New("CreateTUNFromFile not supported on Windows") +} + +func uapiOpen(interfaceName string) (*os.File, error) { + return nil, nil +} + +func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + // On Windows, UAPIListen only takes one parameter + return ipc.UAPIListen(interfaceName) +}