Handle holepunches better

Former-commit-id: 136eee3302
This commit is contained in:
Owen
2025-12-01 13:54:01 -05:00
parent fb007e09a9
commit 7270b840cf
4 changed files with 167 additions and 125 deletions

View File

@@ -235,7 +235,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
OrgID: config.OrgID,
OverrideDNS: config.OverrideDNS,
EnableUAPI: true,
DisableRelay: true,
DisableRelay: false, // allow it to relay
}
go olm.StartTunnel(tunnelConfig)
} else {

View File

@@ -33,7 +33,6 @@ var (
connected bool
dev *device.Device
wgData WgData
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
middleDev *olmDevice.MiddleDevice
@@ -48,13 +47,22 @@ var (
globalConfig GlobalConfig
globalCtx context.Context
stopRegister func()
stopPeerSend func()
updateRegister func(newData interface{})
stopPing chan struct{}
peerManager *peers.PeerManager
)
// initSharedBindAndHolepunch creates the shared UDP socket and holepunch manager.
// initTunnelInfo generates a fresh private key, then creates the shared UDP socket and holepunch manager.
// This is used during initial tunnel setup and when switching organizations.
func initSharedBindAndHolepunch(clientID string) error {
func initTunnelInfo(clientID string) error {
var err error
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Error("Failed to generate private key: %v", err)
return err
}
sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
if err != nil {
return fmt.Errorf("failed to find available UDP port: %w", err)
@@ -82,7 +90,7 @@ func initSharedBindAndHolepunch(clientID string) error {
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
// Create the holepunch manager
holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm")
holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
return nil
}
@@ -249,82 +257,12 @@ func StartTunnel(config TunnelConfig) {
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Error("Failed to generate private key: %v", err)
return
}
// Create shared UDP socket and holepunch manager
if err := initSharedBindAndHolepunch(id); err != nil {
if err := initTunnelInfo(id); err != nil {
logger.Error("%v", err)
return
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
for i, node := range holePunchData.ExitNodes {
exitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
PublicKey: node.PublicKey,
}
}
// Start hole punching using the manager
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch operations
if holePunchManager != nil {
holePunchManager.Stop()
}
// Start hole punching for the exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
@@ -338,9 +276,9 @@ func StartTunnel(config TunnelConfig) {
stopRegister = nil
}
// wait 500 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
if updateRegister != nil {
updateRegister = nil
}
// if there is an existing tunnel then close it
if dev != nil {
@@ -572,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
if stopPeerSend != nil {
stopPeerSend()
stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
@@ -584,6 +527,8 @@ func StartTunnel(config TunnelConfig) {
return
}
holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := peerManager.AddPeer(siteConfig, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
@@ -753,6 +698,59 @@ func StartTunnel(config TunnelConfig) {
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
// Peer-handshake handler: decodes the site/exit-node payload, makes sure the
// exit node is part of the holepunch rotation (starting the manager when it is
// idle), and then begins repeatedly telling the server to add the peer.
olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) {
	logger.Debug("Received peer-handshake message: %v", msg.Data)

	// msg.Data is untyped, so round-trip it through JSON to obtain typed fields.
	raw, marshalErr := json.Marshal(msg.Data)
	if marshalErr != nil {
		logger.Error("Error marshaling handshake data: %v", marshalErr)
		return
	}

	type exitNodeInfo struct {
		PublicKey string `json:"publicKey"`
		Endpoint  string `json:"endpoint"`
	}
	var payload struct {
		SiteId   int          `json:"siteId"`
		ExitNode exitNodeInfo `json:"exitNode"`
	}
	if unmarshalErr := json.Unmarshal(raw, &payload); unmarshalErr != nil {
		logger.Error("Error unmarshaling handshake data: %v", unmarshalErr)
		return
	}
	siteID, endpoint := payload.SiteId, payload.ExitNode.Endpoint

	// Holepunch bookkeeping only applies when a manager exists; the
	// acknowledgment below is sent either way.
	if holePunchManager != nil {
		node := holepunch.ExitNode{
			PublicKey: payload.ExitNode.PublicKey,
			Endpoint:  endpoint,
		}
		if holePunchManager.AddExitNode(node) {
			logger.Info("Added exit node %s to holepunch rotation for handshake", node.Endpoint)
		} else {
			logger.Debug("Exit node %s already in holepunch rotation", node.Endpoint)
		}

		// Kick the manager off if nothing is currently punching.
		if !holePunchManager.IsRunning() {
			if startErr := holePunchManager.Start(); startErr != nil {
				logger.Error("Failed to start holepunch manager: %v", startErr)
			}
		}
	}

	// Acknowledge the handshake to the server on a one-second interval; only
	// the stop function is kept, the update function is discarded.
	ack := map[string]interface{}{
		"siteId": siteID,
	}
	stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", ack, time.Second)

	logger.Info("Initiated handshake for site %d with exit node %s", siteID, endpoint)
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
apiServer.SetTerminated(true)
@@ -779,15 +777,17 @@ func StartTunnel(config TunnelConfig) {
publicKey := privateKey.PublicKey()
// Delay 500ms to give the hole punch time to be processed before registering
time.Sleep(500 * time.Millisecond)
if stopRegister == nil {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": globalConfig.Version,
"orgId": config.OrgID,
"userToken": userToken,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second)
// Invoke onRegistered callback if configured
@@ -801,9 +801,28 @@ func StartTunnel(config TunnelConfig) {
return nil
})
olm.OnTokenUpdate(func(token string) {
// Token-update callback: hands the new token to the holepunch manager and
// starts hole punching against the exit nodes delivered alongside it.
// Does nothing when holePunchManager is nil.
olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
if holePunchManager != nil {
holePunchManager.SetToken(token)
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
// Convert websocket.ExitNode to holepunch.ExitNode
hpExitNodes := make([]holepunch.ExitNode, len(exitNodes))
for i, node := range exitNodes {
hpExitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
PublicKey: node.PublicKey,
}
}
logger.Debug("Updated hole punch exit nodes: %v", hpExitNodes)
// Start hole punching using the manager
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
// A start failure is only logged as a warning; the callback has no
// return value through which to report it.
if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
}
})
@@ -814,6 +833,7 @@ func StartTunnel(config TunnelConfig) {
apiServer.SetRegistered(false)
apiServer.ClearPeerStatuses()
network.ClearNetworkSettings()
Close()
if globalConfig.OnAuthError != nil {
@@ -864,6 +884,10 @@ func Close() {
stopRegister = nil
}
if updateRegister != nil {
updateRegister = nil
}
if peerMonitor != nil {
peerMonitor.Close() // Close() also calls Stop() internally
peerMonitor = nil
@@ -992,7 +1016,7 @@ func SwitchOrg(orgID string) error {
Close()
// Recreate sharedBind and holepunch manager - needed because Close() releases them
if err := initSharedBindAndHolepunch(olmClient.GetConfig().ID); err != nil {
if err := initTunnelInfo(olmClient.GetConfig().ID); err != nil {
return err
}
@@ -1002,7 +1026,7 @@ func SwitchOrg(orgID string) error {
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", orgID)
publicKey := privateKey.PublicKey()
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
stopRegister, updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": true, // Default to relay mode for org switch
"olmVersion": globalConfig.Version,

View File

@@ -12,25 +12,6 @@ type WgData struct {
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
}
type HolePunchMessage struct {
NewtID string `json:"newtId"`
}
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
type HolePunchData struct {
ExitNodes []ExitNode `json:"exitNodes"`
}
type EncryptedHolePunchMessage struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}
type GlobalConfig struct {
// Logging
LogLevel string

View File

@@ -38,12 +38,18 @@ func IsAuthError(err error) bool {
// TokenResponse is the JSON document that getToken decodes from the token
// endpoint's response body: a success flag, a message, and a data object
// holding the token together with the exit nodes.
type TokenResponse struct {
Data struct {
Token string `json:"token"`
Token string `json:"token"`
ExitNodes []ExitNode `json:"exitNodes"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
// ExitNode is one entry of the "exitNodes" array in a token response,
// carrying the node's endpoint and its public key.
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
@@ -71,7 +77,7 @@ type Client struct {
pingInterval time.Duration
pingTimeout time.Duration
onConnect func() error
onTokenUpdate func(token string)
onTokenUpdate func(token string, exitNodes []ExitNode)
onAuthError func(statusCode int, message string) // Callback for auth errors
writeMux sync.Mutex
clientType string // Type of client (e.g., "newt", "olm")
@@ -116,7 +122,7 @@ func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
func (c *Client) OnTokenUpdate(callback func(token string)) {
// OnTokenUpdate stores the callback that establishConnection invokes, when it
// is non-nil, with the token and exit nodes returned by getToken.
func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) {
c.onTokenUpdate = callback
}
@@ -212,13 +218,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
currentData := data
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, data) // Send immediately
err := c.SendMessage(messageType, currentData) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
}
@@ -233,19 +243,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
err = c.SendMessage(messageType, data)
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
if currentMap, ok := currentData.(map[string]interface{}); ok {
if newMap, ok := newData.(map[string]interface{}); ok {
// Update or add keys from newData
for key, value := range newMap {
currentMap[key] = value
}
currentData = currentMap
} else {
// If newData is not a map, replace entirely
currentData = newData
}
} else {
// If currentData is not a map, replace entirely
currentData = newData
}
dataMux.Unlock()
case <-stopChan:
return
}
}
}()
return func() {
close(stopChan)
}
close(stopChan)
}, func(newData interface{}) {
select {
case updateChan <- newData:
case <-stopChan:
// Channel is closed, ignore update
}
}
}
// RegisterHandler registers a handler for a specific message type
@@ -255,11 +292,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlers[messageType] = handler
}
func (c *Client) getToken() (string, error) {
func (c *Client) getToken() (string, []ExitNode, error) {
// Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL: %w", err)
return "", nil, fmt.Errorf("failed to parse base URL: %w", err)
}
// Ensure we have the base URL without trailing slashes
@@ -271,7 +308,7 @@ func (c *Client) getToken() (string, error) {
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
tlsConfig, err = c.setupTLS()
if err != nil {
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err)
}
}
@@ -293,7 +330,7 @@ func (c *Client) getToken() (string, error) {
jsonData, err := json.Marshal(tokenData)
if err != nil {
return "", fmt.Errorf("failed to marshal token request data: %w", err)
return "", nil, fmt.Errorf("failed to marshal token request data: %w", err)
}
// Create a new request
@@ -303,7 +340,7 @@ func (c *Client) getToken() (string, error) {
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return "", nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
@@ -319,7 +356,7 @@ func (c *Client) getToken() (string, error) {
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
return "", nil, fmt.Errorf("failed to request new token: %w", err)
}
defer resp.Body.Close()
@@ -329,33 +366,33 @@ func (c *Client) getToken() (string, error) {
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
return "", &AuthError{
return "", nil, &AuthError{
StatusCode: resp.StatusCode,
Message: string(body),
}
}
// For other errors (5xx, network issues, etc.), return regular error
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
}
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
return "", fmt.Errorf("failed to decode token response: %w", err)
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
if !tokenResp.Success {
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message)
}
if tokenResp.Data.Token == "" {
return "", fmt.Errorf("received empty token from server")
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, nil
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
func (c *Client) connectWithRetry() {
@@ -389,13 +426,13 @@ func (c *Client) connectWithRetry() {
func (c *Client) establishConnection() error {
// Get token for authentication
token, err := c.getToken()
token, exitNodes, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
if c.onTokenUpdate != nil {
c.onTokenUpdate(token)
c.onTokenUpdate(token, exitNodes)
}
// Parse the base URL to determine protocol and hostname