Rudamentary check for p2p connectivity

This commit is contained in:
Owen
2025-02-23 20:17:39 -05:00
parent 3819823d95
commit b6db70e285

197
main.go
View File

@@ -392,20 +392,99 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
}
func monitorConnection(dev *device.Device, onTimeout func()) {
const (
checkInterval = 100 * time.Millisecond // Check every 0.1 seconds
timeout = 500 * time.Millisecond // Total timeout of 1.5 seconds
)
go func() {
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
timeoutTimer := time.NewTimer(timeout)
defer timeoutTimer.Stop()
// var lastSent uint64
for {
select {
case <-ticker.C:
// Get the current device statistics
deviceInfo, err := dev.IpcGet()
if err != nil {
logger.Error("Failed to get device statistics: %v", err)
continue
}
// Parse the statistics from the IPC output
stats := parseStatistics(deviceInfo)
logger.Info("Received: %d, Sent: %d", stats.received, stats.sent)
// Check if we've received any new bytes
if stats.received > 0 {
// Connection is successful, we received data
logger.Info("Connection established - received bytes detected")
return
}
// Update the last known values
// lastSent = stats.sent
case <-timeoutTimer.C:
// We've hit our timeout without seeing any received bytes
logger.Warn("Connection timeout - no data received within %v", timeout)
onTimeout()
return
}
}
}()
}
// statistics holds the parsed byte counts from the device
type statistics struct {
received uint64
sent uint64
}
// parseStatistics extracts the received and sent byte counts from the device info string
func parseStatistics(info string) statistics {
var stats statistics
// Split the device info into lines
lines := strings.Split(info, "\n")
// Look for the transfer_receive and transfer_send lines
for _, line := range lines {
if strings.HasPrefix(line, "rx_bytes=") {
valueStr := strings.TrimPrefix(line, "rx_bytes=")
if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil {
stats.received = value
}
} else if strings.HasPrefix(line, "tx_bytes=") {
valueStr := strings.TrimPrefix(line, "tx_bytes=")
if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil {
stats.sent = value
}
}
}
return stats
}
func main() {
var (
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
interfaceName string
generateAndSaveKeyTo string
reachableAt string
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
interfaceName string
)
stopHolepunch = make(chan struct{})
@@ -419,8 +498,6 @@ func main() {
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
@@ -441,13 +518,7 @@ func main() {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg2", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
}
// do a --version check
@@ -495,18 +566,56 @@ func main() {
var dev *device.Device
var wgData WgData
var uapi *os.File
var tdev tun.Device
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.RegisterHandler("olm/wg/update", func(msg websocket.WSMessage) {
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
endpoint, err := resolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
// Configure WireGuard
config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
err = dev.IpcSet(config)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v", err)
}
})
// Register handlers for different message types
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received message: %v", msg.Data)
close(stopRegister)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
@@ -519,7 +628,7 @@ func main() {
}
// NEED TO DETERMINE AVAILABLE TUN DEVICE HERE
tdev, err := func() (tun.Device, error) {
tdev, err = func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
// if on macOS, call findUnusedUTUN to get a new utun device
@@ -610,24 +719,18 @@ func main() {
logger.Info("UAPI listener started")
// endpoint, err := resolveDomain(wgData.Endpoint)
// if err != nil {
// logger.Error("Failed to resolve endpoint: %v", err)
// return
// }
host, err := resolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
// Configure WireGuard
// config := fmt.Sprintf(`private_key=%s
// public_key=%s
// allowed_ip=%s/32
// endpoint=%s
// persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=18.212.58.121:21820
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP)
endpoint=%s
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
err = dev.IpcSet(config)
if err != nil {
@@ -647,6 +750,30 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
close(stopHolepunch)
// Monitor the connection for activity
monitorConnection(dev, func() {
host, err := resolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
// Configure WireGuard
config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s:21820
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
err = dev.IpcSet(config)
if err != nil {
logger.Error("Failed to configure WireGuard device: %v", err)
}
logger.Info("Adjusted to point to relay!")
})
logger.Info("WireGuard device created.")
})