diff --git a/holepunch/tester.go b/holepunch/tester.go index 3bebc4d..babaac9 100644 --- a/holepunch/tester.go +++ b/holepunch/tester.go @@ -41,6 +41,12 @@ func DefaultTestOptions() TestConnectionOptions { } } +// cachedAddr holds a cached resolved UDP address +type cachedAddr struct { + addr *net.UDPAddr + resolvedAt time.Time +} + // HolepunchTester monitors holepunch connectivity using magic packets type HolepunchTester struct { sharedBind *bind.SharedBind @@ -53,6 +59,11 @@ type HolepunchTester struct { // Callback when connection status changes callback HolepunchStatusCallback + + // Address cache to avoid repeated DNS/UDP resolution + addrCache map[string]*cachedAddr + addrCacheMu sync.RWMutex + addrCacheTTL time.Duration // How long cached addresses are valid } // HolepunchStatus represents the status of a holepunch connection @@ -75,7 +86,9 @@ type pendingRequest struct { // NewHolepunchTester creates a new holepunch tester using the given SharedBind func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester { return &HolepunchTester{ - sharedBind: sharedBind, + sharedBind: sharedBind, + addrCache: make(map[string]*cachedAddr), + addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes } } @@ -135,9 +148,67 @@ func (t *HolepunchTester) Stop() { return true }) + // Clear address cache + t.addrCacheMu.Lock() + t.addrCache = make(map[string]*cachedAddr) + t.addrCacheMu.Unlock() + logger.Debug("HolepunchTester stopped") } +// resolveEndpoint resolves an endpoint to a UDP address, using cache when possible +func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error) { + // Check cache first + t.addrCacheMu.RLock() + cached, ok := t.addrCache[endpoint] + ttl := t.addrCacheTTL + t.addrCacheMu.RUnlock() + + if ok && time.Since(cached.resolvedAt) < ttl { + return cached.addr, nil + } + + // Resolve the endpoint + host, err := util.ResolveDomain(endpoint) + if err != nil { + host = endpoint + } + + _, _, err = net.SplitHostPort(host) + if err != nil { + host = net.JoinHostPort(host, "21820") + } + + remoteAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + return nil, fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + } + + // Cache the result + t.addrCacheMu.Lock() + t.addrCache[endpoint] = &cachedAddr{ + addr: remoteAddr, + resolvedAt: time.Now(), + } + t.addrCacheMu.Unlock() + + return remoteAddr, nil +} + +// InvalidateCache removes a specific endpoint from the address cache +func (t *HolepunchTester) InvalidateCache(endpoint string) { + t.addrCacheMu.Lock() + delete(t.addrCache, endpoint) + t.addrCacheMu.Unlock() +} + +// ClearCache clears all cached addresses +func (t *HolepunchTester) ClearCache() { + t.addrCacheMu.Lock() + t.addrCache = make(map[string]*cachedAddr) + t.addrCacheMu.Unlock() +} + // handleResponse is called by SharedBind when a magic response is received func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) { logger.Debug("Received magic response from %s", addr.String()) @@ -183,20 +254,10 @@ func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) T return result } - // Resolve the endpoint - host, err := util.ResolveDomain(endpoint) + // Resolve the endpoint (using cache) + remoteAddr, err := t.resolveEndpoint(endpoint) if err != nil { - host = endpoint - } - - _, _, err = net.SplitHostPort(host) - if err != nil { - host = net.JoinHostPort(host, "21820") - } - - remoteAddr, err := net.ResolveUDPAddr("udp", host) - if err != nil { - result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + result.Error = err return result }