Holepunch but relay by default

This commit is contained in:
Owen
2025-07-24 14:44:12 -07:00
parent 2a832420df
commit 5302f9da34
4 changed files with 51 additions and 64 deletions

View File

@@ -65,7 +65,7 @@ type EncryptedHolePunchMessage struct {
var ( var (
peerMonitor *peermonitor.PeerMonitor peerMonitor *peermonitor.PeerMonitor
stopHolepunch chan struct{} stopHolepunch chan struct{}
stopRegister chan struct{} stopRegister func()
stopPing chan struct{} stopPing chan struct{}
olmToken string olmToken string
gerbilServerPubKey string gerbilServerPubKey string
@@ -378,35 +378,6 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) {
} }
} }
func sendRegistration(olm *websocket.Client, publicKey string) error {
err := olm.SendMessage("olm/wg/register", map[string]interface{}{
"publicKey": publicKey,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
}
func keepSendingRegistration(olm *websocket.Client, publicKey string) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-stopRegister:
logger.Info("Stopping registration messages")
return
case <-ticker.C:
if err := sendRegistration(olm, publicKey); err != nil {
logger.Error("Failed to send periodic registration: %v", err)
}
}
}
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort { if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)

30
main.go
View File

@@ -157,7 +157,7 @@ func runOlmMain(ctx context.Context) {
func runOlmMainWithArgs(ctx context.Context, args []string) { func runOlmMainWithArgs(ctx context.Context, args []string) {
// Log that we've entered the main function // Log that we've entered the main function
fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args) // fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
// Create a new FlagSet for parsing service arguments // Create a new FlagSet for parsing service arguments
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError) serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
@@ -179,10 +179,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
testTarget string // Add this var for test target testTarget string // Add this var for test target
pingInterval time.Duration pingInterval time.Duration
pingTimeout time.Duration pingTimeout time.Duration
doHolepunch bool
) )
stopHolepunch = make(chan struct{}) stopHolepunch = make(chan struct{})
stopRegister = make(chan struct{})
stopPing = make(chan struct{}) stopPing = make(chan struct{})
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
@@ -196,6 +196,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
httpAddr = os.Getenv("HTTP_ADDR") httpAddr = os.Getenv("HTTP_ADDR")
pingIntervalStr := os.Getenv("PING_INTERVAL") pingIntervalStr := os.Getenv("PING_INTERVAL")
pingTimeoutStr := os.Getenv("PING_TIMEOUT") pingTimeoutStr := os.Getenv("PING_TIMEOUT")
doHolepunch = os.Getenv("HOLEPUNCH") == "true" // Default to true, can be overridden by flag
if endpoint == "" { if endpoint == "" {
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
@@ -227,6 +228,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
if pingTimeoutStr == "" { if pingTimeoutStr == "" {
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
} }
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
// Parse the service arguments // Parse the service arguments
if err := serviceFlags.Parse(args); err != nil { if err := serviceFlags.Parse(args); err != nil {
@@ -442,7 +445,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
connectTimes++ connectTimes++
close(stopRegister) if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// if there is an existing tunnel then close it // if there is an existing tunnel then close it
if dev != nil { if dev != nil {
@@ -566,6 +572,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
fixKey(privateKey.String()), fixKey(privateKey.String()),
olm, olm,
dev, dev,
doHolepunch,
) )
// loop over the sites and call ConfigurePeer for each one // loop over the sites and call ConfigurePeer for each one
@@ -791,9 +798,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
olm.OnConnect(func() error { olm.OnConnect(func() error {
publicKey := privateKey.PublicKey() publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
go keepSendingRegistration(olm, publicKey.String()) logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
}, 1*time.Second)
go keepSendingPing(olm) go keepSendingPing(olm)
if httpServer != nil { if httpServer != nil {
@@ -832,11 +844,9 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
close(stopHolepunch) close(stopHolepunch)
} }
select { if stopRegister != nil {
case <-stopRegister: stopRegister()
// Channel already closed stopRegister = nil
default:
close(stopRegister)
} }
select { select {

View File

@@ -26,31 +26,33 @@ type WireGuardConfig struct {
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers // PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct { type PeerMonitor struct {
monitors map[int]*wgtester.Client monitors map[int]*wgtester.Client
configs map[int]*WireGuardConfig configs map[int]*WireGuardConfig
callback PeerMonitorCallback callback PeerMonitorCallback
mutex sync.Mutex mutex sync.Mutex
running bool running bool
interval time.Duration interval time.Duration
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
privateKey string privateKey string
wsClient *websocket.Client wsClient *websocket.Client
device *device.Device device *device.Device
handleRelaySwitch bool // Whether to handle relay switching
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device) *PeerMonitor { func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
return &PeerMonitor{ return &PeerMonitor{
monitors: make(map[int]*wgtester.Client), monitors: make(map[int]*wgtester.Client),
configs: make(map[int]*WireGuardConfig), configs: make(map[int]*WireGuardConfig),
callback: callback, callback: callback,
interval: 1 * time.Second, // Default check interval interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond, timeout: 2500 * time.Millisecond,
maxAttempts: 8, maxAttempts: 8,
privateKey: privateKey, privateKey: privateKey,
wsClient: wsClient, wsClient: wsClient,
device: device, device: device,
handleRelaySwitch: handleRelaySwitch,
} }
} }
@@ -214,6 +216,10 @@ persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.Server
// sendRelay sends a relay message to the server // sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendRelay(siteID int) error { func (pm *PeerMonitor) sendRelay(siteID int) error {
if !pm.handleRelaySwitch {
return nil
}
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }

View File

@@ -379,7 +379,7 @@ func debugService(args []string) error {
} }
} }
fmt.Printf("Starting service in debug mode...\n") // fmt.Printf("Starting service in debug mode...\n")
// Start the service // Start the service
err := startService([]string{}) // Pass empty args since we already saved them err := startService([]string{}) // Pass empty args since we already saved them
@@ -387,8 +387,8 @@ func debugService(args []string) error {
return fmt.Errorf("failed to start service: %v", err) return fmt.Errorf("failed to start service: %v", err)
} }
fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n") // fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
fmt.Printf("================================================================================\n") // fmt.Printf("================================================================================\n")
// Watch the log file // Watch the log file
return watchLogFile(true) return watchLogFile(true)