mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
31
common.go
31
common.go
@@ -65,7 +65,7 @@ type EncryptedHolePunchMessage struct {
|
||||
var (
|
||||
peerMonitor *peermonitor.PeerMonitor
|
||||
stopHolepunch chan struct{}
|
||||
stopRegister chan struct{}
|
||||
stopRegister func()
|
||||
stopPing chan struct{}
|
||||
olmToken string
|
||||
gerbilServerPubKey string
|
||||
@@ -378,35 +378,6 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) {
|
||||
}
|
||||
}
|
||||
|
||||
func sendRegistration(olm *websocket.Client, publicKey string) error {
|
||||
err := olm.SendMessage("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
}
|
||||
|
||||
func keepSendingRegistration(olm *websocket.Client, publicKey string) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopRegister:
|
||||
logger.Info("Stopping registration messages")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendRegistration(olm, publicKey); err != nil {
|
||||
logger.Error("Failed to send periodic registration: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
|
||||
30
main.go
30
main.go
@@ -157,7 +157,7 @@ func runOlmMain(ctx context.Context) {
|
||||
|
||||
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
// Log that we've entered the main function
|
||||
fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
||||
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
||||
|
||||
// Create a new FlagSet for parsing service arguments
|
||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||
@@ -179,10 +179,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
testTarget string // Add this var for test target
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
doHolepunch bool
|
||||
)
|
||||
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopRegister = make(chan struct{})
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
||||
@@ -196,6 +196,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
httpAddr = os.Getenv("HTTP_ADDR")
|
||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
||||
doHolepunch = os.Getenv("HOLEPUNCH") == "true" // Default to true, can be overridden by flag
|
||||
|
||||
if endpoint == "" {
|
||||
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
||||
@@ -227,6 +228,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
if pingTimeoutStr == "" {
|
||||
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
|
||||
}
|
||||
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
|
||||
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
|
||||
|
||||
// Parse the service arguments
|
||||
if err := serviceFlags.Parse(args); err != nil {
|
||||
@@ -442,7 +445,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
|
||||
connectTimes++
|
||||
|
||||
close(stopRegister)
|
||||
if stopRegister != nil {
|
||||
stopRegister()
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
if dev != nil {
|
||||
@@ -566,6 +572,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
fixKey(privateKey.String()),
|
||||
olm,
|
||||
dev,
|
||||
doHolepunch,
|
||||
)
|
||||
|
||||
// loop over the sites and call ConfigurePeer for each one
|
||||
@@ -791,9 +798,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
|
||||
olm.OnConnect(func() error {
|
||||
publicKey := privateKey.PublicKey()
|
||||
logger.Debug("Public key: %s", publicKey)
|
||||
|
||||
go keepSendingRegistration(olm, publicKey.String())
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !doHolepunch,
|
||||
}, 1*time.Second)
|
||||
|
||||
go keepSendingPing(olm)
|
||||
|
||||
if httpServer != nil {
|
||||
@@ -832,11 +844,9 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||
close(stopHolepunch)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopRegister:
|
||||
// Channel already closed
|
||||
default:
|
||||
close(stopRegister)
|
||||
if stopRegister != nil {
|
||||
stopRegister()
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
@@ -26,31 +26,33 @@ type WireGuardConfig struct {
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||
type PeerMonitor struct {
|
||||
monitors map[int]*wgtester.Client
|
||||
configs map[int]*WireGuardConfig
|
||||
callback PeerMonitorCallback
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
privateKey string
|
||||
wsClient *websocket.Client
|
||||
device *device.Device
|
||||
monitors map[int]*wgtester.Client
|
||||
configs map[int]*WireGuardConfig
|
||||
callback PeerMonitorCallback
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
privateKey string
|
||||
wsClient *websocket.Client
|
||||
device *device.Device
|
||||
handleRelaySwitch bool // Whether to handle relay switching
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device) *PeerMonitor {
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||
return &PeerMonitor{
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
configs: make(map[int]*WireGuardConfig),
|
||||
callback: callback,
|
||||
interval: 1 * time.Second, // Default check interval
|
||||
timeout: 2500 * time.Millisecond,
|
||||
maxAttempts: 8,
|
||||
privateKey: privateKey,
|
||||
wsClient: wsClient,
|
||||
device: device,
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
configs: make(map[int]*WireGuardConfig),
|
||||
callback: callback,
|
||||
interval: 1 * time.Second, // Default check interval
|
||||
timeout: 2500 * time.Millisecond,
|
||||
maxAttempts: 8,
|
||||
privateKey: privateKey,
|
||||
wsClient: wsClient,
|
||||
device: device,
|
||||
handleRelaySwitch: handleRelaySwitch,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,6 +216,10 @@ persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.Server
|
||||
|
||||
// sendRelay sends a relay message to the server
|
||||
func (pm *PeerMonitor) sendRelay(siteID int) error {
|
||||
if !pm.handleRelaySwitch {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pm.wsClient == nil {
|
||||
return fmt.Errorf("websocket client is nil")
|
||||
}
|
||||
|
||||
@@ -379,7 +379,7 @@ func debugService(args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Starting service in debug mode...\n")
|
||||
// fmt.Printf("Starting service in debug mode...\n")
|
||||
|
||||
// Start the service
|
||||
err := startService([]string{}) // Pass empty args since we already saved them
|
||||
@@ -387,8 +387,8 @@ func debugService(args []string) error {
|
||||
return fmt.Errorf("failed to start service: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
||||
fmt.Printf("================================================================================\n")
|
||||
// fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
||||
// fmt.Printf("================================================================================\n")
|
||||
|
||||
// Watch the log file
|
||||
return watchLogFile(true)
|
||||
|
||||
Reference in New Issue
Block a user