mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
Rudamentary check for p2p connectivity
This commit is contained in:
197
main.go
197
main.go
@@ -392,20 +392,99 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func monitorConnection(dev *device.Device, onTimeout func()) {
|
||||
const (
|
||||
checkInterval = 100 * time.Millisecond // Check every 0.1 seconds
|
||||
timeout = 500 * time.Millisecond // Total timeout of 1.5 seconds
|
||||
)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTimer := time.NewTimer(timeout)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
// var lastSent uint64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Get the current device statistics
|
||||
deviceInfo, err := dev.IpcGet()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get device statistics: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the statistics from the IPC output
|
||||
stats := parseStatistics(deviceInfo)
|
||||
|
||||
logger.Info("Received: %d, Sent: %d", stats.received, stats.sent)
|
||||
|
||||
// Check if we've received any new bytes
|
||||
if stats.received > 0 {
|
||||
// Connection is successful, we received data
|
||||
logger.Info("Connection established - received bytes detected")
|
||||
return
|
||||
}
|
||||
|
||||
// Update the last known values
|
||||
// lastSent = stats.sent
|
||||
|
||||
case <-timeoutTimer.C:
|
||||
// We've hit our timeout without seeing any received bytes
|
||||
logger.Warn("Connection timeout - no data received within %v", timeout)
|
||||
onTimeout()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// statistics holds the parsed byte counts from the device
|
||||
type statistics struct {
|
||||
received uint64
|
||||
sent uint64
|
||||
}
|
||||
|
||||
// parseStatistics extracts the received and sent byte counts from the device info string
|
||||
func parseStatistics(info string) statistics {
|
||||
var stats statistics
|
||||
|
||||
// Split the device info into lines
|
||||
lines := strings.Split(info, "\n")
|
||||
|
||||
// Look for the transfer_receive and transfer_send lines
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "rx_bytes=") {
|
||||
valueStr := strings.TrimPrefix(line, "rx_bytes=")
|
||||
if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil {
|
||||
stats.received = value
|
||||
}
|
||||
} else if strings.HasPrefix(line, "tx_bytes=") {
|
||||
valueStr := strings.TrimPrefix(line, "tx_bytes=")
|
||||
if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil {
|
||||
stats.sent = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
interfaceName string
|
||||
generateAndSaveKeyTo string
|
||||
reachableAt string
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
interfaceName string
|
||||
)
|
||||
|
||||
stopHolepunch = make(chan struct{})
|
||||
@@ -419,8 +498,6 @@ func main() {
|
||||
dns = os.Getenv("DNS")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
reachableAt = os.Getenv("REACHABLE_AT")
|
||||
|
||||
if endpoint == "" {
|
||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
||||
@@ -441,13 +518,7 @@ func main() {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg2", "Name of the WireGuard interface")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
if reachableAt == "" {
|
||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
||||
flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
|
||||
}
|
||||
|
||||
// do a --version check
|
||||
@@ -495,18 +566,56 @@ func main() {
|
||||
var dev *device.Device
|
||||
var wgData WgData
|
||||
var uapi *os.File
|
||||
var tdev tun.Device
|
||||
|
||||
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
olm.Close()
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/update", func(msg websocket.WSMessage) {
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
endpoint, err := resolveDomain(wgData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WireGuard
|
||||
config := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=%s
|
||||
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Register handlers for different message types
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received message: %v", msg.Data)
|
||||
|
||||
close(stopRegister)
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
if dev != nil {
|
||||
logger.Info("Got new message. Closing existing tunnel!")
|
||||
dev.Close()
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
@@ -519,7 +628,7 @@ func main() {
|
||||
}
|
||||
|
||||
// NEED TO DETERMINE AVAILABLE TUN DEVICE HERE
|
||||
tdev, err := func() (tun.Device, error) {
|
||||
tdev, err = func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
|
||||
// if on macOS, call findUnusedUTUN to get a new utun device
|
||||
@@ -610,24 +719,18 @@ func main() {
|
||||
|
||||
logger.Info("UAPI listener started")
|
||||
|
||||
// endpoint, err := resolveDomain(wgData.Endpoint)
|
||||
// if err != nil {
|
||||
// logger.Error("Failed to resolve endpoint: %v", err)
|
||||
// return
|
||||
// }
|
||||
host, err := resolveDomain(wgData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WireGuard
|
||||
// config := fmt.Sprintf(`private_key=%s
|
||||
// public_key=%s
|
||||
// allowed_ip=%s/32
|
||||
// endpoint=%s
|
||||
// persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||
|
||||
config := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=18.212.58.121:21820
|
||||
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP)
|
||||
endpoint=%s
|
||||
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
@@ -647,6 +750,30 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
|
||||
close(stopHolepunch)
|
||||
|
||||
// Monitor the connection for activity
|
||||
monitorConnection(dev, func() {
|
||||
host, err := resolveDomain(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Configure WireGuard
|
||||
config := fmt.Sprintf(`private_key=%s
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=%s:21820
|
||||
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Adjusted to point to relay!")
|
||||
})
|
||||
|
||||
logger.Info("WireGuard device created.")
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user