mirror of
https://github.com/fosrl/olm.git
synced 2026-03-01 16:26:43 +00:00
Rudamentary check for p2p connectivity
This commit is contained in:
197
main.go
197
main.go
@@ -392,20 +392,99 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
|||||||
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func monitorConnection(dev *device.Device, onTimeout func()) {
|
||||||
|
const (
|
||||||
|
checkInterval = 100 * time.Millisecond // Check every 0.1 seconds
|
||||||
|
timeout = 500 * time.Millisecond // Total timeout of 1.5 seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeoutTimer := time.NewTimer(timeout)
|
||||||
|
defer timeoutTimer.Stop()
|
||||||
|
|
||||||
|
// var lastSent uint64
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
// Get the current device statistics
|
||||||
|
deviceInfo, err := dev.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get device statistics: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the statistics from the IPC output
|
||||||
|
stats := parseStatistics(deviceInfo)
|
||||||
|
|
||||||
|
logger.Info("Received: %d, Sent: %d", stats.received, stats.sent)
|
||||||
|
|
||||||
|
// Check if we've received any new bytes
|
||||||
|
if stats.received > 0 {
|
||||||
|
// Connection is successful, we received data
|
||||||
|
logger.Info("Connection established - received bytes detected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the last known values
|
||||||
|
// lastSent = stats.sent
|
||||||
|
|
||||||
|
case <-timeoutTimer.C:
|
||||||
|
// We've hit our timeout without seeing any received bytes
|
||||||
|
logger.Warn("Connection timeout - no data received within %v", timeout)
|
||||||
|
onTimeout()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// statistics holds the parsed byte counts from the device
|
||||||
|
type statistics struct {
|
||||||
|
received uint64
|
||||||
|
sent uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseStatistics extracts the received and sent byte counts from the device info string
|
||||||
|
func parseStatistics(info string) statistics {
|
||||||
|
var stats statistics
|
||||||
|
|
||||||
|
// Split the device info into lines
|
||||||
|
lines := strings.Split(info, "\n")
|
||||||
|
|
||||||
|
// Look for the transfer_receive and transfer_send lines
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.HasPrefix(line, "rx_bytes=") {
|
||||||
|
valueStr := strings.TrimPrefix(line, "rx_bytes=")
|
||||||
|
if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil {
|
||||||
|
stats.received = value
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(line, "tx_bytes=") {
|
||||||
|
valueStr := strings.TrimPrefix(line, "tx_bytes=")
|
||||||
|
if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil {
|
||||||
|
stats.sent = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
endpoint string
|
endpoint string
|
||||||
id string
|
id string
|
||||||
secret string
|
secret string
|
||||||
mtu string
|
mtu string
|
||||||
mtuInt int
|
mtuInt int
|
||||||
dns string
|
dns string
|
||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
err error
|
err error
|
||||||
logLevel string
|
logLevel string
|
||||||
interfaceName string
|
interfaceName string
|
||||||
generateAndSaveKeyTo string
|
|
||||||
reachableAt string
|
|
||||||
)
|
)
|
||||||
|
|
||||||
stopHolepunch = make(chan struct{})
|
stopHolepunch = make(chan struct{})
|
||||||
@@ -419,8 +498,6 @@ func main() {
|
|||||||
dns = os.Getenv("DNS")
|
dns = os.Getenv("DNS")
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
logLevel = os.Getenv("LOG_LEVEL")
|
||||||
interfaceName = os.Getenv("INTERFACE")
|
interfaceName = os.Getenv("INTERFACE")
|
||||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
|
||||||
reachableAt = os.Getenv("REACHABLE_AT")
|
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
||||||
@@ -441,13 +518,7 @@ func main() {
|
|||||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
}
|
}
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
flag.StringVar(&interfaceName, "interface", "wg2", "Name of the WireGuard interface")
|
flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
|
||||||
}
|
|
||||||
if generateAndSaveKeyTo == "" {
|
|
||||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
|
||||||
}
|
|
||||||
if reachableAt == "" {
|
|
||||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// do a --version check
|
// do a --version check
|
||||||
@@ -495,18 +566,56 @@ func main() {
|
|||||||
var dev *device.Device
|
var dev *device.Device
|
||||||
var wgData WgData
|
var wgData WgData
|
||||||
var uapi *os.File
|
var uapi *os.File
|
||||||
|
var tdev tun.Device
|
||||||
|
|
||||||
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received terminate message")
|
logger.Info("Received terminate message")
|
||||||
olm.Close()
|
olm.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
olm.RegisterHandler("olm/wg/update", func(msg websocket.WSMessage) {
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint, err := resolveDomain(wgData.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve endpoint: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure WireGuard
|
||||||
|
config := fmt.Sprintf(`private_key=%s
|
||||||
|
public_key=%s
|
||||||
|
allowed_ip=%s/32
|
||||||
|
endpoint=%s
|
||||||
|
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||||
|
|
||||||
|
err = dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Register handlers for different message types
|
// Register handlers for different message types
|
||||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received message: %v", msg.Data)
|
logger.Info("Received message: %v", msg.Data)
|
||||||
|
|
||||||
close(stopRegister)
|
close(stopRegister)
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
if dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Error marshaling data: %v", err)
|
logger.Info("Error marshaling data: %v", err)
|
||||||
@@ -519,7 +628,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NEED TO DETERMINE AVAILABLE TUN DEVICE HERE
|
// NEED TO DETERMINE AVAILABLE TUN DEVICE HERE
|
||||||
tdev, err := func() (tun.Device, error) {
|
tdev, err = func() (tun.Device, error) {
|
||||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
|
|
||||||
// if on macOS, call findUnusedUTUN to get a new utun device
|
// if on macOS, call findUnusedUTUN to get a new utun device
|
||||||
@@ -610,24 +719,18 @@ func main() {
|
|||||||
|
|
||||||
logger.Info("UAPI listener started")
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
// endpoint, err := resolveDomain(wgData.Endpoint)
|
host, err := resolveDomain(wgData.Endpoint)
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// logger.Error("Failed to resolve endpoint: %v", err)
|
logger.Error("Failed to resolve endpoint: %v", err)
|
||||||
// return
|
return
|
||||||
// }
|
}
|
||||||
|
|
||||||
// Configure WireGuard
|
// Configure WireGuard
|
||||||
// config := fmt.Sprintf(`private_key=%s
|
|
||||||
// public_key=%s
|
|
||||||
// allowed_ip=%s/32
|
|
||||||
// endpoint=%s
|
|
||||||
// persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
|
||||||
|
|
||||||
config := fmt.Sprintf(`private_key=%s
|
config := fmt.Sprintf(`private_key=%s
|
||||||
public_key=%s
|
public_key=%s
|
||||||
allowed_ip=%s/32
|
allowed_ip=%s/32
|
||||||
endpoint=18.212.58.121:21820
|
endpoint=%s
|
||||||
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP)
|
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
|
||||||
|
|
||||||
err = dev.IpcSet(config)
|
err = dev.IpcSet(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -647,6 +750,30 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
}
|
}
|
||||||
|
|
||||||
close(stopHolepunch)
|
close(stopHolepunch)
|
||||||
|
|
||||||
|
// Monitor the connection for activity
|
||||||
|
monitorConnection(dev, func() {
|
||||||
|
host, err := resolveDomain(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve endpoint: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure WireGuard
|
||||||
|
config := fmt.Sprintf(`private_key=%s
|
||||||
|
public_key=%s
|
||||||
|
allowed_ip=%s/32
|
||||||
|
endpoint=%s:21820
|
||||||
|
persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
|
||||||
|
|
||||||
|
err = dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Adjusted to point to relay!")
|
||||||
|
})
|
||||||
|
|
||||||
logger.Info("WireGuard device created.")
|
logger.Info("WireGuard device created.")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user