Handle holepunches better
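
This commit reworks the hole punch manager around a mutable exit-node rotation: the manager now carries the client's WireGuard public key, keeps an exitNodes map behind its mutex, re-resolves endpoints when signaled over updateChan, and backs the send interval off exponentially between sendHolepunchIntervalMin and sendHolepunchIntervalMax. For orientation only (not part of the diff), here is a rough sketch of how a caller might drive the resulting API; it assumes the holepunch package types shown in the hunks below, and the endpoint names and keys are placeholders:

// Illustrative sketch only (not part of the commit). Assumes the holepunch
// package from this repo; endpoint names and public keys are placeholders.
func exampleRotation(mgr *holepunch.Manager) error {
	// Seed the rotation and start the background hole punch loop.
	err := mgr.StartMultipleExitNodes([]holepunch.ExitNode{
		{Endpoint: "exit1.example.com", PublicKey: "server-pubkey-1"},
	})
	if err != nil {
		return err
	}

	// Nodes can be added or removed while the loop runs; AddExitNode and
	// RemoveExitNode nudge the goroutine over updateChan so it re-resolves.
	mgr.AddExitNode(holepunch.ExitNode{Endpoint: "exit2.example.com", PublicKey: "server-pubkey-2"})
	mgr.RemoveExitNode("exit1.example.com")

	// Punch all configured nodes immediately instead of waiting for the ticker.
	return mgr.TriggerHolePunch()
}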
@@ -184,7 +184,7 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
 
 	// Create the holepunch manager with ResolveDomain function
 	// We'll need to pass a domain resolver function
-	service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt")
+	service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
 
 	// Register websocket handlers
 	wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
@@ -30,20 +30,29 @@ type Manager struct {
 	sharedBind *bind.SharedBind
 	ID         string
 	token      string
+	publicKey  string
 	clientType string
+	exitNodes  map[string]ExitNode // key is endpoint
+	updateChan chan struct{}       // signals the goroutine to refresh exit nodes
+
+	sendHolepunchInterval time.Duration
 }
 
+const sendHolepunchIntervalMax = 60 * time.Second
+const sendHolepunchIntervalMin = 1 * time.Second
+
 // NewManager creates a new hole punch manager
-func NewManager(sharedBind *bind.SharedBind, ID string, clientType string) *Manager {
+func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
 	return &Manager{
 		sharedBind: sharedBind,
 		ID:         ID,
 		clientType: clientType,
+		publicKey:  publicKey,
+		exitNodes:  make(map[string]ExitNode),
+		sendHolepunchInterval: sendHolepunchIntervalMin,
 	}
 }
 
-const sendHolepunchInterval = 15 * time.Second
-
 // SetToken updates the authentication token used for hole punching
 func (m *Manager) SetToken(token string) {
 	m.mu.Lock()
@@ -72,10 +81,129 @@ func (m *Manager) Stop() {
 		m.stopChan = nil
 	}
 
+	if m.updateChan != nil {
+		close(m.updateChan)
+		m.updateChan = nil
+	}
+
 	m.running = false
 	logger.Info("Hole punch manager stopped")
 }
 
+// AddExitNode adds a new exit node to the rotation if it doesn't already exist
+func (m *Manager) AddExitNode(exitNode ExitNode) bool {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if _, exists := m.exitNodes[exitNode.Endpoint]; exists {
+		logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint)
+		return false
+	}
+
+	m.exitNodes[exitNode.Endpoint] = exitNode
+	logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint)
+
+	// Signal the goroutine to refresh if running
+	if m.running && m.updateChan != nil {
+		select {
+		case m.updateChan <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
+
+	return true
+}
+
+// RemoveExitNode removes an exit node from the rotation
+func (m *Manager) RemoveExitNode(endpoint string) bool {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if _, exists := m.exitNodes[endpoint]; !exists {
+		logger.Debug("Exit node %s not found in rotation", endpoint)
+		return false
+	}
+
+	delete(m.exitNodes, endpoint)
+	logger.Info("Removed exit node %s from hole punch rotation", endpoint)
+
+	// Signal the goroutine to refresh if running
+	if m.running && m.updateChan != nil {
+		select {
+		case m.updateChan <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
+
+	return true
+}
+
+// GetExitNodes returns a copy of the current exit nodes
+func (m *Manager) GetExitNodes() []ExitNode {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	nodes := make([]ExitNode, 0, len(m.exitNodes))
+	for _, node := range m.exitNodes {
+		nodes = append(nodes, node)
+	}
+	return nodes
+}
+
+// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes
+// This is useful for triggering hole punching on demand without waiting for the interval
+func (m *Manager) TriggerHolePunch() error {
+	m.mu.Lock()
+
+	if len(m.exitNodes) == 0 {
+		m.mu.Unlock()
+		return fmt.Errorf("no exit nodes configured")
+	}
+
+	// Get a copy of exit nodes to work with
+	currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
+	for _, node := range m.exitNodes {
+		currentExitNodes = append(currentExitNodes, node)
+	}
+	m.mu.Unlock()
+
+	logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes))
+
+	// Send hole punch to all exit nodes
+	successCount := 0
+	for _, exitNode := range currentExitNodes {
+		host, err := util.ResolveDomain(exitNode.Endpoint)
+		if err != nil {
+			logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
+			continue
+		}
+
+		serverAddr := net.JoinHostPort(host, "21820")
+		remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
+		if err != nil {
+			logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
+			continue
+		}
+
+		if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil {
+			logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err)
+			continue
+		}
+
+		logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint)
+		successCount++
+	}
+
+	if successCount == 0 {
+		return fmt.Errorf("failed to send hole punch to any exit node")
+	}
+
+	logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes))
+	return nil
+}
+
 // StartMultipleExitNodes starts hole punching to multiple exit nodes
 func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
 	m.mu.Lock()
@@ -92,13 +220,48 @@ func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
 		return fmt.Errorf("no exit nodes provided")
 	}
 
+	// Populate exit nodes map
+	m.exitNodes = make(map[string]ExitNode)
+	for _, node := range exitNodes {
+		m.exitNodes[node.Endpoint] = node
+	}
+
 	m.running = true
 	m.stopChan = make(chan struct{})
+	m.updateChan = make(chan struct{}, 1)
 	m.mu.Unlock()
 
 	logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
 
-	go m.runMultipleExitNodes(exitNodes)
+	go m.runMultipleExitNodes()
+
+	return nil
+}
+
+// Start starts hole punching with the current set of exit nodes
+func (m *Manager) Start() error {
+	m.mu.Lock()
+
+	if m.running {
+		m.mu.Unlock()
+		logger.Debug("UDP hole punch already running")
+		return fmt.Errorf("hole punch already running")
+	}
+
+	if len(m.exitNodes) == 0 {
+		m.mu.Unlock()
+		logger.Warn("No exit nodes configured for hole punching")
+		return fmt.Errorf("no exit nodes configured")
+	}
+
+	m.running = true
+	m.stopChan = make(chan struct{})
+	m.updateChan = make(chan struct{}, 1)
+	m.mu.Unlock()
+
+	logger.Info("Starting UDP hole punch with %d exit nodes", len(m.exitNodes))
+
+	go m.runMultipleExitNodes()
 
 	return nil
 }
@@ -125,7 +288,7 @@ func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
 }
 
 // runMultipleExitNodes performs hole punching to multiple exit nodes
-func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
+func (m *Manager) runMultipleExitNodes() {
 	defer func() {
 		m.mu.Lock()
 		m.running = false
@@ -140,29 +303,41 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
 		endpointName string
 	}
 
-	var resolvedNodes []resolvedExitNode
-	for _, exitNode := range exitNodes {
-		host, err := util.ResolveDomain(exitNode.Endpoint)
-		if err != nil {
-			logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
-			continue
-		}
-
-		serverAddr := net.JoinHostPort(host, "21820")
-		remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
-		if err != nil {
-			logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
-			continue
-		}
-
-		resolvedNodes = append(resolvedNodes, resolvedExitNode{
-			remoteAddr:   remoteAddr,
-			publicKey:    exitNode.PublicKey,
-			endpointName: exitNode.Endpoint,
-		})
-		logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
-	}
+	resolveNodes := func() []resolvedExitNode {
+		m.mu.Lock()
+		currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
+		for _, node := range m.exitNodes {
+			currentExitNodes = append(currentExitNodes, node)
+		}
+		m.mu.Unlock()
+
+		var resolvedNodes []resolvedExitNode
+		for _, exitNode := range currentExitNodes {
+			host, err := util.ResolveDomain(exitNode.Endpoint)
+			if err != nil {
+				logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
+				continue
+			}
+
+			serverAddr := net.JoinHostPort(host, "21820")
+			remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
+			if err != nil {
+				logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
+				continue
+			}
+
+			resolvedNodes = append(resolvedNodes, resolvedExitNode{
+				remoteAddr:   remoteAddr,
+				publicKey:    exitNode.PublicKey,
+				endpointName: exitNode.Endpoint,
+			})
+			logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
+		}
+		return resolvedNodes
+	}
+
+	resolvedNodes := resolveNodes()
 
 	if len(resolvedNodes) == 0 {
 		logger.Error("No exit nodes could be resolved")
 		return
@@ -175,7 +350,12 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
 		}
 	}
 
-	ticker := time.NewTicker(sendHolepunchInterval)
+	// Start with minimum interval
+	m.mu.Lock()
+	m.sendHolepunchInterval = sendHolepunchIntervalMin
+	m.mu.Unlock()
+
+	ticker := time.NewTicker(m.sendHolepunchInterval)
 	defer ticker.Stop()
 
 	for {
@@ -183,6 +363,24 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
 		case <-m.stopChan:
 			logger.Debug("Hole punch stopped by signal")
 			return
+		case <-m.updateChan:
+			// Re-resolve exit nodes when update is signaled
+			logger.Info("Refreshing exit nodes for hole punching")
+			resolvedNodes = resolveNodes()
+			if len(resolvedNodes) == 0 {
+				logger.Warn("No exit nodes available after refresh")
+			}
+			// Reset interval to minimum on update
+			m.mu.Lock()
+			m.sendHolepunchInterval = sendHolepunchIntervalMin
+			m.mu.Unlock()
+			ticker.Reset(m.sendHolepunchInterval)
+			// Send immediate hole punch to newly resolved nodes
+			for _, node := range resolvedNodes {
+				if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
+					logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
+				}
+			}
 		case <-ticker.C:
 			// Send hole punch to all exit nodes
 			for _, node := range resolvedNodes {
@@ -190,6 +388,18 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
 					logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
 				}
 			}
+			// Exponential backoff: double the interval up to max
+			m.mu.Lock()
+			newInterval := m.sendHolepunchInterval * 2
+			if newInterval > sendHolepunchIntervalMax {
+				newInterval = sendHolepunchIntervalMax
+			}
+			if newInterval != m.sendHolepunchInterval {
+				m.sendHolepunchInterval = newInterval
+				ticker.Reset(m.sendHolepunchInterval)
+				logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
+			}
+			m.mu.Unlock()
 		}
 	}
 }
@@ -222,7 +432,12 @@ func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
 		logger.Warn("Failed to send initial hole punch: %v", err)
 	}
 
-	ticker := time.NewTicker(sendHolepunchInterval)
+	// Start with minimum interval
+	m.mu.Lock()
+	m.sendHolepunchInterval = sendHolepunchIntervalMin
+	m.mu.Unlock()
+
+	ticker := time.NewTicker(m.sendHolepunchInterval)
 	defer ticker.Stop()
 
 	for {
@@ -234,6 +449,18 @@ func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
 			if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
 				logger.Debug("Failed to send hole punch: %v", err)
 			}
+			// Exponential backoff: double the interval up to max
+			m.mu.Lock()
+			newInterval := m.sendHolepunchInterval * 2
+			if newInterval > sendHolepunchIntervalMax {
+				newInterval = sendHolepunchIntervalMax
+			}
+			if newInterval != m.sendHolepunchInterval {
+				m.sendHolepunchInterval = newInterval
+				ticker.Reset(m.sendHolepunchInterval)
+				logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
+			}
+			m.mu.Unlock()
 		}
 	}
 }
@@ -252,19 +479,23 @@ func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) er
 	var payload interface{}
 	if m.clientType == "newt" {
 		payload = struct {
 			ID        string `json:"newtId"`
 			Token     string `json:"token"`
+			PublicKey string `json:"publicKey"`
 		}{
 			ID:        ID,
 			Token:     token,
+			PublicKey: m.publicKey,
 		}
 	} else {
 		payload = struct {
 			ID        string `json:"olmId"`
 			Token     string `json:"token"`
+			PublicKey string `json:"publicKey"`
 		}{
 			ID:        ID,
 			Token:     token,
+			PublicKey: m.publicKey,
 		}
 	}
 
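
For reference only (not part of the commit), a minimal standalone sketch of the payload shape once PublicKey is included, assuming the payload is JSON-encoded as the struct tags above suggest; all values are placeholders:

// Illustrative sketch: shows the wire shape implied by the struct tags above.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Mirrors the anonymous struct used for the "newt" client type.
	payload := struct {
		ID        string `json:"newtId"`
		Token     string `json:"token"`
		PublicKey string `json:"publicKey"`
	}{
		ID:        "example-newt-id",       // placeholder
		Token:     "example-token",         // placeholder
		PublicKey: "example-wg-public-key", // placeholder
	}

	out, _ := json.Marshal(payload)
	// Prints: {"newtId":"example-newt-id","token":"example-token","publicKey":"example-wg-public-key"}
	fmt.Println(string(out))
}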