mirror of https://github.com/fosrl/newt.git
Handle holepunches better
@@ -184,7 +184,7 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str

     // Create the holepunch manager with ResolveDomain function
     // We'll need to pass a domain resolver function
-    service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt")
+    service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())

     // Register websocket handlers
     wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
@@ -30,20 +30,29 @@ type Manager struct {
     sharedBind *bind.SharedBind
     ID         string
     token      string
+    publicKey  string
     clientType string
+    exitNodes  map[string]ExitNode // key is endpoint
+    updateChan chan struct{}       // signals the goroutine to refresh exit nodes
+
+    sendHolepunchInterval time.Duration
 }

+const sendHolepunchIntervalMax = 60 * time.Second
+const sendHolepunchIntervalMin = 1 * time.Second
+
 // NewManager creates a new hole punch manager
-func NewManager(sharedBind *bind.SharedBind, ID string, clientType string) *Manager {
+func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
     return &Manager{
-        sharedBind: sharedBind,
-        ID:         ID,
-        clientType: clientType,
+        sharedBind:            sharedBind,
+        ID:                    ID,
+        clientType:            clientType,
+        publicKey:             publicKey,
+        exitNodes:             make(map[string]ExitNode),
+        sendHolepunchInterval: sendHolepunchIntervalMin,
     }
 }

-const sendHolepunchInterval = 15 * time.Second
-
 // SetToken updates the authentication token used for hole punching
 func (m *Manager) SetToken(token string) {
     m.mu.Lock()
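A minimal usage sketch of the new constructor signature (sharedBind, newtId, key, and token are identifiers taken from the hunks above, not a complete program):

    // Sketch only: assumes the surrounding WireGuard service setup shown above.
    mgr := holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
    mgr.SetToken(token) // the token is sent in every hole punch payload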
@@ -72,10 +81,129 @@ func (m *Manager) Stop() {
         m.stopChan = nil
     }

+    if m.updateChan != nil {
+        close(m.updateChan)
+        m.updateChan = nil
+    }
+
     m.running = false
     logger.Info("Hole punch manager stopped")
 }

+// AddExitNode adds a new exit node to the rotation if it doesn't already exist
+func (m *Manager) AddExitNode(exitNode ExitNode) bool {
+    m.mu.Lock()
+    defer m.mu.Unlock()
+
+    if _, exists := m.exitNodes[exitNode.Endpoint]; exists {
+        logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint)
+        return false
+    }
+
+    m.exitNodes[exitNode.Endpoint] = exitNode
+    logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint)
+
+    // Signal the goroutine to refresh if running
+    if m.running && m.updateChan != nil {
+        select {
+        case m.updateChan <- struct{}{}:
+        default:
+            // Channel full or closed, skip
+        }
+    }
+
+    return true
+}
+
+// RemoveExitNode removes an exit node from the rotation
+func (m *Manager) RemoveExitNode(endpoint string) bool {
+    m.mu.Lock()
+    defer m.mu.Unlock()
+
+    if _, exists := m.exitNodes[endpoint]; !exists {
+        logger.Debug("Exit node %s not found in rotation", endpoint)
+        return false
+    }
+
+    delete(m.exitNodes, endpoint)
+    logger.Info("Removed exit node %s from hole punch rotation", endpoint)
+
+    // Signal the goroutine to refresh if running
+    if m.running && m.updateChan != nil {
+        select {
+        case m.updateChan <- struct{}{}:
+        default:
+            // Channel full or closed, skip
+        }
+    }
+
+    return true
+}
+
+// GetExitNodes returns a copy of the current exit nodes
+func (m *Manager) GetExitNodes() []ExitNode {
+    m.mu.Lock()
+    defer m.mu.Unlock()
+
+    nodes := make([]ExitNode, 0, len(m.exitNodes))
+    for _, node := range m.exitNodes {
+        nodes = append(nodes, node)
+    }
+    return nodes
+}
+
+// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes
+// This is useful for triggering hole punching on demand without waiting for the interval
+func (m *Manager) TriggerHolePunch() error {
+    m.mu.Lock()
+
+    if len(m.exitNodes) == 0 {
+        m.mu.Unlock()
+        return fmt.Errorf("no exit nodes configured")
+    }
+
+    // Get a copy of exit nodes to work with
+    currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
+    for _, node := range m.exitNodes {
+        currentExitNodes = append(currentExitNodes, node)
+    }
+    m.mu.Unlock()
+
+    logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes))
+
+    // Send hole punch to all exit nodes
+    successCount := 0
+    for _, exitNode := range currentExitNodes {
+        host, err := util.ResolveDomain(exitNode.Endpoint)
+        if err != nil {
+            logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
+            continue
+        }
+
+        serverAddr := net.JoinHostPort(host, "21820")
+        remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
+        if err != nil {
+            logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
+            continue
+        }
+
+        if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil {
+            logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err)
+            continue
+        }
+
+        logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint)
+        successCount++
+    }
+
+    if successCount == 0 {
+        return fmt.Errorf("failed to send hole punch to any exit node")
+    }
+
+    logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes))
+    return nil
+}
+
 // StartMultipleExitNodes starts hole punching to multiple exit nodes
 func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
     m.mu.Lock()
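The added AddExitNode, RemoveExitNode, GetExitNodes, and TriggerHolePunch methods let the exit node set change while the punch loop is running. A hedged sketch of dynamic use (ExitNode's Endpoint and PublicKey fields are referenced in the code above; the values here are placeholders):

    node := holepunch.ExitNode{
        Endpoint:  "exit1.example.com",        // placeholder endpoint
        PublicKey: "base64-wireguard-pubkey=", // placeholder key
    }
    if mgr.AddExitNode(node) {
        // Punch immediately instead of waiting for the next tick.
        if err := mgr.TriggerHolePunch(); err != nil {
            logger.Warn("on-demand hole punch failed: %v", err)
        }
    }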
@@ -92,13 +220,48 @@ func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
         return fmt.Errorf("no exit nodes provided")
     }

+    // Populate exit nodes map
+    m.exitNodes = make(map[string]ExitNode)
+    for _, node := range exitNodes {
+        m.exitNodes[node.Endpoint] = node
+    }
+
     m.running = true
     m.stopChan = make(chan struct{})
+    m.updateChan = make(chan struct{}, 1)
     m.mu.Unlock()

     logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))

-    go m.runMultipleExitNodes(exitNodes)
+    go m.runMultipleExitNodes()

     return nil
 }

+// Start starts hole punching with the current set of exit nodes
+func (m *Manager) Start() error {
+    m.mu.Lock()
+
+    if m.running {
+        m.mu.Unlock()
+        logger.Debug("UDP hole punch already running")
+        return fmt.Errorf("hole punch already running")
+    }
+
+    if len(m.exitNodes) == 0 {
+        m.mu.Unlock()
+        logger.Warn("No exit nodes configured for hole punching")
+        return fmt.Errorf("no exit nodes configured")
+    }
+
+    m.running = true
+    m.stopChan = make(chan struct{})
+    m.updateChan = make(chan struct{}, 1)
+    m.mu.Unlock()
+
+    logger.Info("Starting UDP hole punch with %d exit nodes", len(m.exitNodes))
+
+    go m.runMultipleExitNodes()
+
+    return nil
+}
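Note: the added Start() punches whatever set has already been registered via AddExitNode, whereas StartMultipleExitNodes replaces the map with the nodes it is given; both spawn the same runMultipleExitNodes loop.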
@@ -125,7 +288,7 @@ func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
 }

 // runMultipleExitNodes performs hole punching to multiple exit nodes
-func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
+func (m *Manager) runMultipleExitNodes() {
     defer func() {
         m.mu.Lock()
         m.running = false
@@ -140,29 +303,41 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
         endpointName string
     }

-    var resolvedNodes []resolvedExitNode
-    for _, exitNode := range exitNodes {
-        host, err := util.ResolveDomain(exitNode.Endpoint)
-        if err != nil {
-            logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
-            continue
-        }
-
-        serverAddr := net.JoinHostPort(host, "21820")
-        remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
-        if err != nil {
-            logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
-            continue
-        }
-
-        resolvedNodes = append(resolvedNodes, resolvedExitNode{
-            remoteAddr:   remoteAddr,
-            publicKey:    exitNode.PublicKey,
-            endpointName: exitNode.Endpoint,
-        })
-        logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
-    }
-
+    resolveNodes := func() []resolvedExitNode {
+        m.mu.Lock()
+        currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
+        for _, node := range m.exitNodes {
+            currentExitNodes = append(currentExitNodes, node)
+        }
+        m.mu.Unlock()
+        var resolvedNodes []resolvedExitNode
+        for _, exitNode := range currentExitNodes {
+            host, err := util.ResolveDomain(exitNode.Endpoint)
+            if err != nil {
+                logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
+                continue
+            }
+
+            serverAddr := net.JoinHostPort(host, "21820")
+            remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
+            if err != nil {
+                logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
+                continue
+            }
+
+            resolvedNodes = append(resolvedNodes, resolvedExitNode{
+                remoteAddr:   remoteAddr,
+                publicKey:    exitNode.PublicKey,
+                endpointName: exitNode.Endpoint,
+            })
+            logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
+        }
+
+        return resolvedNodes
+    }
+
+    resolvedNodes := resolveNodes()
+
     if len(resolvedNodes) == 0 {
         logger.Error("No exit nodes could be resolved")
         return
@@ -175,7 +350,12 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
         }
     }

-    ticker := time.NewTicker(sendHolepunchInterval)
+    // Start with minimum interval
+    m.mu.Lock()
+    m.sendHolepunchInterval = sendHolepunchIntervalMin
+    m.mu.Unlock()
+
+    ticker := time.NewTicker(m.sendHolepunchInterval)
     defer ticker.Stop()

     for {
@@ -183,6 +363,24 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
         case <-m.stopChan:
             logger.Debug("Hole punch stopped by signal")
             return
+        case <-m.updateChan:
+            // Re-resolve exit nodes when update is signaled
+            logger.Info("Refreshing exit nodes for hole punching")
+            resolvedNodes = resolveNodes()
+            if len(resolvedNodes) == 0 {
+                logger.Warn("No exit nodes available after refresh")
+            }
+            // Reset interval to minimum on update
+            m.mu.Lock()
+            m.sendHolepunchInterval = sendHolepunchIntervalMin
+            m.mu.Unlock()
+            ticker.Reset(m.sendHolepunchInterval)
+            // Send immediate hole punch to newly resolved nodes
+            for _, node := range resolvedNodes {
+                if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
+                    logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
+                }
+            }
         case <-ticker.C:
             // Send hole punch to all exit nodes
             for _, node := range resolvedNodes {
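Because updateChan is buffered with capacity 1 and senders use a non-blocking select, a burst of AddExitNode/RemoveExitNode calls coalesces into one refresh here; the refresh also resets the backoff to sendHolepunchIntervalMin and punches the re-resolved nodes right away.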
@@ -190,6 +388,18 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
                     logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
                 }
             }
+            // Exponential backoff: double the interval up to max
+            m.mu.Lock()
+            newInterval := m.sendHolepunchInterval * 2
+            if newInterval > sendHolepunchIntervalMax {
+                newInterval = sendHolepunchIntervalMax
+            }
+            if newInterval != m.sendHolepunchInterval {
+                m.sendHolepunchInterval = newInterval
+                ticker.Reset(m.sendHolepunchInterval)
+                logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
+            }
+            m.mu.Unlock()
         }
     }
 }
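With sendHolepunchIntervalMin = 1s and sendHolepunchIntervalMax = 60s, the punch interval grows 1s, 2s, 4s, 8s, 16s, 32s, then clamps to 60s (64s would exceed the max); any exit node update drops it back to 1s.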
@@ -222,7 +432,12 @@ func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
         logger.Warn("Failed to send initial hole punch: %v", err)
     }

-    ticker := time.NewTicker(sendHolepunchInterval)
+    // Start with minimum interval
+    m.mu.Lock()
+    m.sendHolepunchInterval = sendHolepunchIntervalMin
+    m.mu.Unlock()
+
+    ticker := time.NewTicker(m.sendHolepunchInterval)
     defer ticker.Stop()

     for {
@@ -234,6 +449,18 @@ func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
             if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
                 logger.Debug("Failed to send hole punch: %v", err)
             }
+            // Exponential backoff: double the interval up to max
+            m.mu.Lock()
+            newInterval := m.sendHolepunchInterval * 2
+            if newInterval > sendHolepunchIntervalMax {
+                newInterval = sendHolepunchIntervalMax
+            }
+            if newInterval != m.sendHolepunchInterval {
+                m.sendHolepunchInterval = newInterval
+                ticker.Reset(m.sendHolepunchInterval)
+                logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
+            }
+            m.mu.Unlock()
         }
     }
 }
@@ -252,19 +479,23 @@ func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) er
     var payload interface{}
     if m.clientType == "newt" {
         payload = struct {
-            ID    string `json:"newtId"`
-            Token string `json:"token"`
+            ID        string `json:"newtId"`
+            Token     string `json:"token"`
+            PublicKey string `json:"publicKey"`
         }{
-            ID:    ID,
-            Token: token,
+            ID:        ID,
+            Token:     token,
+            PublicKey: m.publicKey,
         }
     } else {
         payload = struct {
-            ID    string `json:"olmId"`
-            Token string `json:"token"`
+            ID        string `json:"olmId"`
+            Token     string `json:"token"`
+            PublicKey string `json:"publicKey"`
         }{
-            ID:    ID,
-            Token: token,
+            ID:        ID,
+            Token:     token,
+            PublicKey: m.publicKey,
         }
     }

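With the PublicKey field added, the marshaled hole punch payload for a newt client would look roughly like this (placeholder values; the real ID, token, and key come from the manager's state; assumes encoding/json):

    payload := struct {
        ID        string `json:"newtId"`
        Token     string `json:"token"`
        PublicKey string `json:"publicKey"`
    }{ID: "newt-abc123", Token: "example-token", PublicKey: "AbC...dEf="}
    b, _ := json.Marshal(payload)
    // string(b) == {"newtId":"newt-abc123","token":"example-token","publicKey":"AbC...dEf="}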