mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Update main.go
fixed some errors, This file should be ok.
This commit is contained in:
97
main.go
97
main.go
@@ -50,11 +50,16 @@ type TargetData struct {
|
||||
}
|
||||
|
||||
func fixKey(key string) string {
|
||||
// Remove any whitespace
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// Decode from base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
logger.Fatal("Error decoding base64: %v", err)
|
||||
}
|
||||
|
||||
// Convert to hex
|
||||
return hex.EncodeToString(decoded)
|
||||
}
|
||||
|
||||
@@ -110,7 +115,6 @@ func ping(tnet *netstack.Net, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- CHANGED: added healthFile as parameter ---
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}, healthFile string) {
|
||||
initialInterval := 10 * time.Second
|
||||
maxInterval := 60 * time.Second
|
||||
@@ -127,23 +131,28 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{},
|
||||
err := ping(tnet, serverIP)
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
|
||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
|
||||
consecutiveFailures, err)
|
||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||
// --- CHANGED: Only remove file if healthFile is set ---
|
||||
|
||||
// Only remove file if healthFile is set
|
||||
if consecutiveFailures >= 3 && healthFile != "" {
|
||||
_ = os.Remove(healthFile)
|
||||
}
|
||||
|
||||
// Increase interval if we have consistent failures, with a maximum cap
|
||||
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
||||
// Increase by 50% each time, up to the maximum
|
||||
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
||||
if currentInterval > maxInterval {
|
||||
currentInterval = maxInterval
|
||||
}
|
||||
ticker.Reset(currentInterval)
|
||||
logger.Info("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
||||
logger.Info("Increased ping check interval to %v due to consecutive failures",
|
||||
currentInterval)
|
||||
}
|
||||
} else {
|
||||
// --- CHANGED: Only write file if healthFile is set ---
|
||||
// Only write file if healthFile is set
|
||||
if healthFile != "" {
|
||||
err := os.WriteFile(healthFile, []byte("ok"), 0644)
|
||||
if err != nil {
|
||||
@@ -157,7 +166,8 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{},
|
||||
currentInterval = initialInterval
|
||||
}
|
||||
ticker.Reset(currentInterval)
|
||||
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
|
||||
logger.Info("Decreased ping check interval to %v after successful ping",
|
||||
currentInterval)
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
@@ -169,6 +179,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{},
|
||||
}()
|
||||
}
|
||||
|
||||
// Function to track connection status and trigger reconnection as needed
|
||||
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
|
||||
const checkInterval = 30 * time.Second
|
||||
connectionLost := false
|
||||
@@ -178,18 +189,27 @@ func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websoc
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Try a ping to see if connection is alive
|
||||
err := ping(tnet, serverIP)
|
||||
|
||||
if err != nil && !connectionLost {
|
||||
// We just lost connection
|
||||
connectionLost = true
|
||||
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||
|
||||
// Notify the user they might need to check their network
|
||||
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
|
||||
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
|
||||
} else if err == nil && connectionLost {
|
||||
// Connection has been restored
|
||||
connectionLost = false
|
||||
logger.Info("Connection to server restored!")
|
||||
|
||||
// Tell the server we're back
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": privateKey.PublicKey().String(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message after reconnection: %v", err)
|
||||
} else {
|
||||
@@ -204,25 +224,32 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
const (
|
||||
initialMaxAttempts = 15
|
||||
initialRetryDelay = 2 * time.Second
|
||||
maxRetryDelay = 60 * time.Second
|
||||
maxRetryDelay = 60 * time.Second // Cap the maximum delay
|
||||
)
|
||||
|
||||
attempt := 1
|
||||
retryDelay := initialRetryDelay
|
||||
|
||||
// First try with the initial parameters
|
||||
logger.Info("Ping attempt %d", attempt)
|
||||
if err := ping(tnet, dst); err == nil {
|
||||
// Successful ping
|
||||
return nil
|
||||
} else {
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
}
|
||||
|
||||
// Start a goroutine that will attempt pings indefinitely with increasing delays
|
||||
go func() {
|
||||
attempt = 2
|
||||
attempt = 2 // Continue from attempt 2
|
||||
|
||||
for {
|
||||
logger.Info("Ping attempt %d", attempt)
|
||||
|
||||
if err := ping(tnet, dst); err != nil {
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
|
||||
// Increase delay after certain thresholds but cap it
|
||||
if attempt%5 == 0 && retryDelay < maxRetryDelay {
|
||||
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
||||
if retryDelay > maxRetryDelay {
|
||||
@@ -230,14 +257,18 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
}
|
||||
logger.Info("Increasing ping retry delay to %v", retryDelay)
|
||||
}
|
||||
|
||||
time.Sleep(retryDelay)
|
||||
attempt++
|
||||
} else {
|
||||
// Successful ping
|
||||
logger.Info("Ping succeeded after %d attempts", attempt)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
|
||||
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||
}
|
||||
|
||||
@@ -254,7 +285,7 @@ func parseLogLevel(level string) logger.LogLevel {
|
||||
case "FATAL":
|
||||
return logger.FATAL
|
||||
default:
|
||||
return logger.INFO
|
||||
return logger.INFO // default to INFO if invalid level provided
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,6 +293,8 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
switch level {
|
||||
case logger.DEBUG:
|
||||
return device.LogLevelVerbose
|
||||
// case logger.INFO:
|
||||
// return device.LogLevel
|
||||
case logger.WARN:
|
||||
return device.LogLevelError
|
||||
case logger.ERROR, logger.FATAL:
|
||||
@@ -272,23 +305,32 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
}
|
||||
|
||||
func resolveDomain(domain string) (string, error) {
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Remove any protocol prefix if present
|
||||
if strings.HasPrefix(host, "http://") {
|
||||
host = strings.TrimPrefix(host, "http://")
|
||||
} else if strings.HasPrefix(host, "https://") {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
@@ -296,16 +338,20 @@ func resolveDomain(domain string) (string, error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
// --- ADDED: healthFile variable ---
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
@@ -323,6 +369,7 @@ var (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
id = os.Getenv("NEWT_ID")
|
||||
secret = os.Getenv("NEWT_SECRET")
|
||||
@@ -361,12 +408,14 @@ func main() {
|
||||
if dockerSocket == "" {
|
||||
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
|
||||
}
|
||||
// --- ADDED: CLI flag for healthFile if not set by env ---
|
||||
// CLI flag for healthFile if not set by env
|
||||
if healthFile == "" {
|
||||
flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won’t be written)")
|
||||
}
|
||||
|
||||
// do a --version check
|
||||
version := flag.Bool("version", false, "Print the version")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
newtVersion := "Newt version replaceme"
|
||||
@@ -381,6 +430,7 @@ func main() {
|
||||
loggerLevel := parseLogLevel(logLevel)
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// parse the mtu string into an int
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
@@ -394,13 +444,18 @@ func main() {
|
||||
if tlsPrivateKey != "" {
|
||||
opt = websocket.WithTLSConfig(tlsPrivateKey)
|
||||
}
|
||||
// Create a new client
|
||||
client, err := websocket.NewClient(
|
||||
id, secret, endpoint, opt,
|
||||
id, // CLI arg takes precedence
|
||||
secret, // CLI arg takes precedence
|
||||
endpoint,
|
||||
opt,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
var dev *device.Device
|
||||
@@ -422,12 +477,14 @@ func main() {
|
||||
pingStopChan := make(chan struct{})
|
||||
defer close(pingStopChan)
|
||||
|
||||
// Register handlers for different message types
|
||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received registration message")
|
||||
|
||||
if connected {
|
||||
logger.Info("Already connected! But I will send a ping anyway...")
|
||||
_ = pingWithRetry(tnet, wgData.ServerIP)
|
||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||
_ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue
|
||||
return
|
||||
}
|
||||
|
||||
@@ -451,6 +508,7 @@ func main() {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
}
|
||||
|
||||
// Create WireGuard device
|
||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
||||
mapToWireGuardLogLevel(loggerLevel),
|
||||
"wireguard: ",
|
||||
@@ -462,6 +520,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WireGuard
|
||||
config := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
@@ -473,6 +532,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// Bring up the device
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
@@ -480,21 +540,29 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
|
||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||
|
||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||
_ = pingWithRetry(tnet, wgData.ServerIP)
|
||||
|
||||
// Always mark as connected and start the proxy manager regardless of initial ping result
|
||||
// as the pings will continue in the background
|
||||
if !connected {
|
||||
logger.Info("Starting ping check")
|
||||
// --- CHANGED: Pass healthFile to startPingCheck ---
|
||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan, healthFile)
|
||||
|
||||
// Start connection monitoring in a separate goroutine
|
||||
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
pm = proxy.NewProxyManager(tnet)
|
||||
|
||||
connected = true
|
||||
|
||||
// add the targets if there are any
|
||||
if len(wgData.Targets.TCP) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
||||
}
|
||||
|
||||
if len(wgData.Targets.UDP) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||
}
|
||||
@@ -505,7 +573,6 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user