mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
251 lines
5.4 KiB
Go
251 lines
5.4 KiB
Go
//go:build !js
|
|
|
|
package portforward
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/libp2p/go-nat"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Lease handling for mappings created on the gateway.
const (
	// defaultMappingTTL is the lifetime passed to the gateway with every
	// AddPortMapping request.
	defaultMappingTTL = 2 * time.Hour
	// renewalInterval re-requests the mapping halfway through its lifetime,
	// leaving room for one failed renewal before the lease runs out.
	renewalInterval = defaultMappingTTL / 2
)

// discoveryTimeout bounds how long gateway discovery may take.
const discoveryTimeout = 10 * time.Second

// mappingDescription is the label attached to mappings created on the gateway.
const mappingDescription = "NetBird"
|
|
|
|
// Mapping describes a port mapping created on the NAT gateway for the local
// WireGuard port.
type Mapping struct {
	// Protocol is the transport protocol of the mapping ("udp" for mappings
	// created by this package).
	Protocol string
	// InternalPort is the local port being forwarded; ExternalPort is the port
	// the gateway exposes publicly (it can change when the mapping is renewed).
	InternalPort, ExternalPort uint16
	// ExternalIP is the public address reported by the gateway. The lookup is
	// best-effort, so this may be empty if it failed.
	ExternalIP net.IP
	// NATType is the gateway type string reported by the NAT library.
	NATType string
}
|
|
|
|
type Manager struct {
|
|
cancel context.CancelFunc
|
|
|
|
mapping *Mapping
|
|
mappingLock sync.Mutex
|
|
|
|
wgPort uint16
|
|
|
|
done chan struct{}
|
|
stopCtx chan context.Context
|
|
|
|
// protect exported functions
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func NewManager() *Manager {
|
|
return &Manager{
|
|
stopCtx: make(chan context.Context, 1),
|
|
}
|
|
}
|
|
|
|
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
|
|
m.mu.Lock()
|
|
if m.cancel != nil {
|
|
m.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
if isDisabledByEnv() {
|
|
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
|
|
m.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
if wgPort == 0 {
|
|
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
|
|
m.mu.Unlock()
|
|
return
|
|
}
|
|
m.wgPort = wgPort
|
|
|
|
m.done = make(chan struct{})
|
|
defer close(m.done)
|
|
|
|
ctx, m.cancel = context.WithCancel(ctx)
|
|
m.mu.Unlock()
|
|
|
|
gateway, mapping, err := m.setup(ctx)
|
|
if err != nil {
|
|
log.Errorf("failed to setup NAT port mapping: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
m.mappingLock.Lock()
|
|
m.mapping = mapping
|
|
m.mappingLock.Unlock()
|
|
|
|
m.renewLoop(ctx, gateway)
|
|
|
|
select {
|
|
case cleanupCtx := <-m.stopCtx:
|
|
// block the Start while cleaned up gracefully
|
|
m.cleanup(cleanupCtx, gateway)
|
|
default:
|
|
// return Start immediately and cleanup in background
|
|
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
go func() {
|
|
defer cleanupCancel()
|
|
m.cleanup(cleanupCtx, gateway)
|
|
}()
|
|
}
|
|
}
|
|
|
|
// GetMapping returns the current mapping if ready, nil otherwise
|
|
func (m *Manager) GetMapping() *Mapping {
|
|
m.mappingLock.Lock()
|
|
defer m.mappingLock.Unlock()
|
|
|
|
if m.mapping == nil {
|
|
return nil
|
|
}
|
|
|
|
mapping := *m.mapping
|
|
return &mapping
|
|
}
|
|
|
|
// GracefullyStop cancels the manager and attempts to delete the port mapping.
|
|
// After GracefullyStop returns, the manager cannot be restarted.
|
|
func (m *Manager) GracefullyStop(ctx context.Context) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.cancel == nil {
|
|
return nil
|
|
}
|
|
|
|
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
|
|
m.startTearDown(ctx)
|
|
|
|
m.cancel()
|
|
m.cancel = nil
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-m.done:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
|
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
|
defer discoverCancel()
|
|
|
|
gateway, err := nat.DiscoverGateway(discoverCtx)
|
|
if err != nil {
|
|
log.Infof("NAT gateway discovery failed: %v (port forwarding disabled)", err)
|
|
return nil, nil, err
|
|
}
|
|
|
|
log.Infof("discovered NAT gateway: %s", gateway.Type())
|
|
|
|
mapping, err := m.createMapping(ctx, gateway)
|
|
if err != nil {
|
|
log.Warnf("failed to create port mapping: %v", err)
|
|
return nil, nil, err
|
|
}
|
|
return gateway, mapping, nil
|
|
}
|
|
|
|
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, defaultMappingTTL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
externalIP, err := gateway.GetExternalAddress()
|
|
if err != nil {
|
|
log.Debugf("failed to get external address: %v", err)
|
|
// todo return with err?
|
|
}
|
|
|
|
mapping := &Mapping{
|
|
Protocol: "udp",
|
|
InternalPort: m.wgPort,
|
|
ExternalPort: uint16(externalPort),
|
|
ExternalIP: externalIP,
|
|
NATType: gateway.Type(),
|
|
}
|
|
|
|
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
|
|
m.wgPort, externalPort, gateway.Type(), externalIP)
|
|
return mapping, nil
|
|
}
|
|
|
|
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT) {
|
|
ticker := time.NewTicker(renewalInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if err := m.renewMapping(ctx, gateway); err != nil {
|
|
log.Warnf("failed to renew port mapping: %v", err)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, defaultMappingTTL)
|
|
if err != nil {
|
|
return fmt.Errorf("add port mapping: %w", err)
|
|
}
|
|
|
|
if uint16(externalPort) != m.mapping.ExternalPort {
|
|
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
|
|
m.mappingLock.Lock()
|
|
m.mapping.ExternalPort = uint16(externalPort)
|
|
m.mappingLock.Unlock()
|
|
}
|
|
|
|
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
|
|
m.mappingLock.Lock()
|
|
mapping := m.mapping
|
|
m.mapping = nil
|
|
m.mappingLock.Unlock()
|
|
|
|
if mapping == nil {
|
|
return
|
|
}
|
|
|
|
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
|
|
log.Warnf("delete port mapping on stop: %v", err)
|
|
return
|
|
}
|
|
|
|
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
|
|
}
|
|
|
|
func (m *Manager) startTearDown(ctx context.Context) {
|
|
select {
|
|
case m.stopCtx <- ctx:
|
|
default:
|
|
}
|
|
}
|