Files
newt/main.go
2025-06-18 22:54:13 -04:00

774 lines
21 KiB
Go

package main
import (
"encoding/json"
"flag"
"fmt"
"math"
"net/http"
"net/netip"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/fosrl/newt/docker"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"github.com/fosrl/newt/wgtester"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WgData struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"`
}
type TargetsByType struct {
UDP []string `json:"udp"`
TCP []string `json:"tcp"`
}
type TargetData struct {
Targets []string `json:"targets"`
}
type ExitNodeData struct {
ExitNodes []ExitNode `json:"exitNodes"`
}
// ExitNode represents an exit node with an ID, endpoint, and weight.
type ExitNode struct {
ID int `json:"exitNodeId"`
Name string `json:"exitNodeName"`
Endpoint string `json:"endpoint"`
Weight float64 `json:"weight"`
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
var (
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
interfaceName string
generateAndSaveKeyTo string
rm bool
acceptClients bool
updownScript string
tlsPrivateKey string
dockerSocket string
publicKey wgtypes.Key
)
func main() {
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
id = os.Getenv("NEWT_ID")
secret = os.Getenv("NEWT_SECRET")
mtu = os.Getenv("MTU")
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
rm = os.Getenv("RM") == "true"
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
}
if id == "" {
flag.StringVar(&id, "id", "", "Newt ID")
}
if secret == "" {
flag.StringVar(&secret, "secret", "", "Newt secret")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1280", "MTU to use")
}
if dns == "" {
flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if updownScript == "" {
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key")
}
flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface")
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
if tlsPrivateKey == "" {
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
}
if dockerSocket == "" {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
flag.Parse()
newtVersion := "Newt version replaceme"
if *version {
fmt.Println(newtVersion)
os.Exit(0)
} else {
logger.Info(newtVersion)
}
logger.Init()
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// parse the mtu string into an int
mtuInt, err = strconv.Atoi(mtu)
if err != nil {
logger.Fatal("Failed to parse MTU: %v", err)
}
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
var opt websocket.ClientOption
if tlsPrivateKey != "" {
opt = websocket.WithTLSConfig(tlsPrivateKey)
}
// Create a new client
client, err := websocket.NewClient(
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
opt,
)
if err != nil {
logger.Fatal("Failed to create client: %v", err)
}
var wgService *wg.WireGuardService
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
var dev *device.Device
var pm *proxy.ProxyManager
var connected bool
var wgData WgData
var wgTesterServer *wgtester.Server
if acceptClients {
// make sure we are running on linux
if runtime.GOOS != "linux" {
logger.Fatal("Tunnel management is only supported on Linux right now!")
os.Exit(1)
}
var host = endpoint
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
host = strings.TrimSuffix(host, "/")
// Create WireGuard service
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close(rm)
wgTesterServer = wgtester.NewServer("0.0.0.0", wgService.Port, id) // TODO: maybe make this the same ip of the wg server?
err := wgTesterServer.Start()
if err != nil {
logger.Error("Failed to start WireGuard tester server: %v", err)
} else {
// Make sure to stop the server on exit
defer wgTesterServer.Stop()
}
}
pingStopChan := make(chan struct{})
defer close(pingStopChan)
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message")
if connected {
// Stop proxy manager if running
if pm != nil {
pm.Stop()
pm = nil
}
// Close WireGuard device if running
if dev != nil {
dev.Close()
dev = nil
}
// Close TUN/netstack if running
if tnet != nil {
tnet = nil
}
if tun != nil {
tun.Close()
tun = nil
}
// Stop the ping check
close(pingStopChan)
// Mark as disconnected
connected = false
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
if wgService != nil {
wgService.SetServerPubKey(wgData.PublicKey)
}
logger.Info("Received: %+v", msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
[]netip.Addr{netip.MustParseAddr(dns)},
mtuInt)
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
}
// Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
endpoint, err := resolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
// Configure WireGuard
config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
err = dev.IpcSet(config)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v", err)
}
// Bring up the device
err = dev.Up()
if err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
logger.Info("WireGuard device created. Lets ping the server now...")
// Even if pingWithRetry returns an error, it will continue trying in the background
_ = pingWithRetry(tnet, wgData.ServerIP)
// Always mark as connected and start the proxy manager regardless of initial ping result
// as the pings will continue in the background
if !connected {
logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, client, pingStopChan)
}
// Create proxy manager
pm = proxy.NewProxyManager(tnet)
connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
// first make sure the wpgService has a port
if wgService != nil {
// add a udp proxy for localost and the wgService port
// TODO: make sure this port is not used in a target
pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
}
})
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
logger.Info("Received disconnect message")
// Stop proxy manager if running
if pm != nil {
pm.Stop()
pm = nil
}
// Close WireGuard device if running
if dev != nil {
dev.Close()
dev = nil
}
// Close TUN/netstack if running
if tnet != nil {
tnet = nil
}
if tun != nil {
tun.Close()
tun = nil
}
// Stop the ping check
close(pingStopChan)
// Mark as disconnected
connected = false
logger.Info("Tunnel destroyed, ready for reconnection")
})
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
logger.Info("Received ping message")
// Parse the incoming list of exit nodes
// Exit nodes is a json
var exitNodeData ExitNodeData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &exitNodeData); err != nil {
logger.Info("Error unmarshaling exit node data: %v", err)
return
}
exitNodes := exitNodeData.ExitNodes
if len(exitNodes) == 0 {
logger.Info("No exit nodes provided")
return
}
type nodeResult struct {
Node ExitNode
Latency time.Duration
Err error
}
results := make([]nodeResult, len(exitNodes))
for i, node := range exitNodes {
start := time.Now()
client := &http.Client{
Timeout: 5 * time.Second,
}
url := node.Endpoint
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
url = "http://" + url
}
if !strings.HasSuffix(url, "/ping") {
url = strings.TrimRight(url, "/") + "/ping"
}
resp, err := client.Get(url)
latency := time.Since(start)
if err != nil {
logger.Warn("Failed to ping exit node %s (%s): %v", node.ID, url, err)
results[i] = nodeResult{Node: node, Latency: latency, Err: err}
continue
}
resp.Body.Close()
results[i] = nodeResult{Node: node, Latency: latency, Err: nil}
// logger.Info("Exit node %s latency: %v", node.Name, latency)
}
// we will need to tweak these
const (
latencyPenaltyExponent = 1.5 // make latency matter more
lastNodeScoreBoost = 1.10 // 10% preference for the last used node
scoreTolerancePercent = 5.0 // allow last node if within 5% of best score
)
var bestNode *ExitNode
var bestScore float64 = -1e12
var bestLatency time.Duration = 1e12
type ExitNodeScore struct {
Node ExitNode
Score float64
Latency time.Duration
}
var candidateNodes []ExitNodeScore
for _, res := range results {
if res.Err != nil || res.Node.Weight <= 0 {
continue
}
latencyMs := float64(res.Latency.Milliseconds())
score := res.Node.Weight / math.Pow(latencyMs, latencyPenaltyExponent)
// slight boost if this is the last used node
if res.Node.WasPreviouslyConnected == true {
score *= lastNodeScoreBoost
}
logger.Info("Exit node %s with score: %.2f (latency: %dms, weight: %.2f)", res.Node.Name, score, res.Latency.Milliseconds(), res.Node.Weight)
candidateNodes = append(candidateNodes, ExitNodeScore{Node: res.Node, Score: score, Latency: res.Latency})
if score > bestScore {
bestScore = score
bestLatency = res.Latency
bestNode = &res.Node
} else if score == bestScore && res.Latency < bestLatency {
bestLatency = res.Latency
bestNode = &res.Node
}
}
// check if last used node is close enough in score
for _, cand := range candidateNodes {
if cand.Node.WasPreviouslyConnected {
if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) {
logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint)
bestNode = &cand.Node
}
break
}
}
if bestNode == nil {
logger.Warn("No suitable exit node found")
return
}
logger.Info("Selected exit node: %s (%s)", bestNode.Name, bestNode.Endpoint)
err = client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"exitNodeId": bestNode.ID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return
}
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
// Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Info("Received Docker socket check request")
if dockerSocket == "" {
logger.Info("Docker socket path is not set")
err := client.SendMessage("newt/socket/status", map[string]interface{}{
"available": false,
"socketPath": dockerSocket,
})
if err != nil {
logger.Error("Failed to send Docker socket check response: %v", err)
}
return
}
// Check if Docker socket is available
isAvailable := docker.CheckSocket(dockerSocket)
// Send response back to server
err := client.SendMessage("newt/socket/status", map[string]interface{}{
"available": isAvailable,
"socketPath": dockerSocket,
})
if err != nil {
logger.Error("Failed to send Docker socket check response: %v", err)
} else {
logger.Info("Docker socket check response sent: available=%t", isAvailable)
}
})
// Register handler for Docker container listing
client.RegisterHandler("newt/socket/fetch", func(msg websocket.WSMessage) {
logger.Info("Received Docker container fetch request")
if dockerSocket == "" {
logger.Info("Docker socket path is not set")
return
}
// List Docker containers
containers, err := docker.ListContainers(dockerSocket)
if err != nil {
logger.Error("Failed to list Docker containers: %v", err)
return
}
// Send container list back to server
err = client.SendMessage("newt/socket/containers", map[string]interface{}{
"containers": containers,
})
if err != nil {
logger.Error("Failed to send Docker container list: %v", err)
} else {
logger.Info("Docker container list sent, count: %d", len(containers))
}
})
client.OnConnect(func() error {
publicKey = privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
// request from the server the list of nodes to ping at newt/ping/request
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
if wgService != nil {
wgService.LoadRemoteConfig()
}
logger.Info("Sent registration message")
return nil
})
client.OnTokenUpdate(func(token string) {
if wgService != nil {
wgService.SetToken(token)
}
})
// Connect to the WebSocket server
if err := client.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer client.Close()
// Wait for interrupt signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
dev.Close()
if wgService != nil {
wgService.Close(rm)
}
if wgTesterServer != nil {
wgTesterServer.Stop()
}
if pm != nil {
pm.Stop()
}
if client != nil {
client.Close()
}
logger.Info("Exiting...")
os.Exit(0)
}
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) {
initialInterval := 10 * time.Second
maxInterval := 60 * time.Second
currentInterval := initialInterval
consecutiveFailures := 0
connectionLost := false
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
go func() {
for {
select {
case <-ticker.C:
_, err := ping(tnet, serverIP)
if err != nil {
consecutiveFailures++
// Check if this is the first failure (connection just lost)
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
}
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
consecutiveFailures, err)
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
// Increase interval if we have consistent failures, with a maximum cap
if consecutiveFailures >= 5 && currentInterval < maxInterval {
// Increase by 50% each time, up to the maximum
currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures",
currentInterval)
// Restart the connection flow
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
}
} else {
// Check if connection was previously lost and is now restored
if connectionLost {
connectionLost = false
logger.Info("Connection to server restored!")
}
// On success, if we've backed off, gradually return to normal interval
if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval {
currentInterval = initialInterval
}
ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping",
currentInterval)
}
consecutiveFailures = 0
}
case <-stopChan:
logger.Info("Stopping ping check")
return
}
}
}()
}