mirror of
https://github.com/fosrl/olm.git
synced 2026-03-01 08:16:56 +00:00
Holepunch but relay by default
This commit is contained in:
31
common.go
31
common.go
@@ -65,7 +65,7 @@ type EncryptedHolePunchMessage struct {
|
|||||||
var (
|
var (
|
||||||
peerMonitor *peermonitor.PeerMonitor
|
peerMonitor *peermonitor.PeerMonitor
|
||||||
stopHolepunch chan struct{}
|
stopHolepunch chan struct{}
|
||||||
stopRegister chan struct{}
|
stopRegister func()
|
||||||
stopPing chan struct{}
|
stopPing chan struct{}
|
||||||
olmToken string
|
olmToken string
|
||||||
gerbilServerPubKey string
|
gerbilServerPubKey string
|
||||||
@@ -378,35 +378,6 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendRegistration(olm *websocket.Client, publicKey string) error {
|
|
||||||
err := olm.SendMessage("olm/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": publicKey,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Info("Sent registration message")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func keepSendingRegistration(olm *websocket.Client, publicKey string) {
|
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stopRegister:
|
|
||||||
logger.Info("Stopping registration messages")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := sendRegistration(olm, publicKey); err != nil {
|
|
||||||
logger.Error("Failed to send periodic registration: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||||
if maxPort < minPort {
|
if maxPort < minPort {
|
||||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||||
|
|||||||
30
main.go
30
main.go
@@ -157,7 +157,7 @@ func runOlmMain(ctx context.Context) {
|
|||||||
|
|
||||||
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||||
// Log that we've entered the main function
|
// Log that we've entered the main function
|
||||||
fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
||||||
|
|
||||||
// Create a new FlagSet for parsing service arguments
|
// Create a new FlagSet for parsing service arguments
|
||||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||||
@@ -179,10 +179,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
testTarget string // Add this var for test target
|
testTarget string // Add this var for test target
|
||||||
pingInterval time.Duration
|
pingInterval time.Duration
|
||||||
pingTimeout time.Duration
|
pingTimeout time.Duration
|
||||||
|
doHolepunch bool
|
||||||
)
|
)
|
||||||
|
|
||||||
stopHolepunch = make(chan struct{})
|
stopHolepunch = make(chan struct{})
|
||||||
stopRegister = make(chan struct{})
|
|
||||||
stopPing = make(chan struct{})
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
||||||
@@ -196,6 +196,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
httpAddr = os.Getenv("HTTP_ADDR")
|
httpAddr = os.Getenv("HTTP_ADDR")
|
||||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||||
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
||||||
|
doHolepunch = os.Getenv("HOLEPUNCH") == "true" // Default to true, can be overridden by flag
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
||||||
@@ -227,6 +228,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
if pingTimeoutStr == "" {
|
if pingTimeoutStr == "" {
|
||||||
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
|
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
|
||||||
}
|
}
|
||||||
|
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
|
||||||
|
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
|
||||||
|
|
||||||
// Parse the service arguments
|
// Parse the service arguments
|
||||||
if err := serviceFlags.Parse(args); err != nil {
|
if err := serviceFlags.Parse(args); err != nil {
|
||||||
@@ -442,7 +445,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
connectTimes++
|
connectTimes++
|
||||||
|
|
||||||
close(stopRegister)
|
if stopRegister != nil {
|
||||||
|
stopRegister()
|
||||||
|
stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
// if there is an existing tunnel then close it
|
// if there is an existing tunnel then close it
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
@@ -566,6 +572,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
fixKey(privateKey.String()),
|
fixKey(privateKey.String()),
|
||||||
olm,
|
olm,
|
||||||
dev,
|
dev,
|
||||||
|
doHolepunch,
|
||||||
)
|
)
|
||||||
|
|
||||||
// loop over the sites and call ConfigurePeer for each one
|
// loop over the sites and call ConfigurePeer for each one
|
||||||
@@ -791,9 +798,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
olm.OnConnect(func() error {
|
olm.OnConnect(func() error {
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
logger.Debug("Public key: %s", publicKey)
|
|
||||||
|
|
||||||
go keepSendingRegistration(olm, publicKey.String())
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||||
|
|
||||||
|
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": publicKey.String(),
|
||||||
|
"relay": !doHolepunch,
|
||||||
|
}, 1*time.Second)
|
||||||
|
|
||||||
go keepSendingPing(olm)
|
go keepSendingPing(olm)
|
||||||
|
|
||||||
if httpServer != nil {
|
if httpServer != nil {
|
||||||
@@ -832,11 +844,9 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
close(stopHolepunch)
|
close(stopHolepunch)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
if stopRegister != nil {
|
||||||
case <-stopRegister:
|
stopRegister()
|
||||||
// Channel already closed
|
stopRegister = nil
|
||||||
default:
|
|
||||||
close(stopRegister)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -26,31 +26,33 @@ type WireGuardConfig struct {
|
|||||||
|
|
||||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||||
type PeerMonitor struct {
|
type PeerMonitor struct {
|
||||||
monitors map[int]*wgtester.Client
|
monitors map[int]*wgtester.Client
|
||||||
configs map[int]*WireGuardConfig
|
configs map[int]*WireGuardConfig
|
||||||
callback PeerMonitorCallback
|
callback PeerMonitorCallback
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
running bool
|
running bool
|
||||||
interval time.Duration
|
interval time.Duration
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
privateKey string
|
privateKey string
|
||||||
wsClient *websocket.Client
|
wsClient *websocket.Client
|
||||||
device *device.Device
|
device *device.Device
|
||||||
|
handleRelaySwitch bool // Whether to handle relay switching
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device) *PeerMonitor {
|
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||||
return &PeerMonitor{
|
return &PeerMonitor{
|
||||||
monitors: make(map[int]*wgtester.Client),
|
monitors: make(map[int]*wgtester.Client),
|
||||||
configs: make(map[int]*WireGuardConfig),
|
configs: make(map[int]*WireGuardConfig),
|
||||||
callback: callback,
|
callback: callback,
|
||||||
interval: 1 * time.Second, // Default check interval
|
interval: 1 * time.Second, // Default check interval
|
||||||
timeout: 2500 * time.Millisecond,
|
timeout: 2500 * time.Millisecond,
|
||||||
maxAttempts: 8,
|
maxAttempts: 8,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
wsClient: wsClient,
|
wsClient: wsClient,
|
||||||
device: device,
|
device: device,
|
||||||
|
handleRelaySwitch: handleRelaySwitch,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,6 +216,10 @@ persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.Server
|
|||||||
|
|
||||||
// sendRelay sends a relay message to the server
|
// sendRelay sends a relay message to the server
|
||||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||||
|
if !pm.handleRelaySwitch {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if pm.wsClient == nil {
|
if pm.wsClient == nil {
|
||||||
return fmt.Errorf("websocket client is nil")
|
return fmt.Errorf("websocket client is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -379,7 +379,7 @@ func debugService(args []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Starting service in debug mode...\n")
|
// fmt.Printf("Starting service in debug mode...\n")
|
||||||
|
|
||||||
// Start the service
|
// Start the service
|
||||||
err := startService([]string{}) // Pass empty args since we already saved them
|
err := startService([]string{}) // Pass empty args since we already saved them
|
||||||
@@ -387,8 +387,8 @@ func debugService(args []string) error {
|
|||||||
return fmt.Errorf("failed to start service: %v", err)
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
// fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
||||||
fmt.Printf("================================================================================\n")
|
// fmt.Printf("================================================================================\n")
|
||||||
|
|
||||||
// Watch the log file
|
// Watch the log file
|
||||||
return watchLogFile(true)
|
return watchLogFile(true)
|
||||||
|
|||||||
Reference in New Issue
Block a user