diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..a3a4ba40f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -40,7 +40,6 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } func ListenPort() uint16 { @@ -49,11 +48,16 @@ func ListenPort() uint16 { return listenPort } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { +func SetListenPort(port uint16) { + listenPortMu.Lock() + listenPort = port + listenPortMu.Unlock() +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, } } @@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() - } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 646e059d4..bebf04f6c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder( return } + if forwarderPort > 0 { + dnsfwd.SetListenPort(forwarderPort) + } + if !enabled { if e.dnsForwardMgr == nil { return @@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder( if len(fwdEntries) > 0 { switch { case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil @@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { log.Errorf("failed to stop DNS forward: %v", err) } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil