mirror of
https://github.com/fosrl/newt.git
synced 2026-03-03 17:26:43 +00:00
Change DNS lookup to conntrack
This commit is contained in:
@@ -151,59 +151,6 @@ type natState struct {
|
|||||||
rewrittenTo netip.Addr // The address we rewrote to
|
rewrittenTo netip.Addr // The address we rewrote to
|
||||||
}
|
}
|
||||||
|
|
||||||
// dnsCache entry for caching resolved addresses
|
|
||||||
type dnsCacheEntry struct {
|
|
||||||
addr netip.Addr
|
|
||||||
expiresAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnsCache provides TTL-based caching for DNS lookups
|
|
||||||
type dnsCache struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
entries map[string]*dnsCacheEntry
|
|
||||||
ttl time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDNSCache creates a new DNS cache with the specified TTL
|
|
||||||
func newDNSCache(ttl time.Duration) *dnsCache {
|
|
||||||
return &dnsCache{
|
|
||||||
entries: make(map[string]*dnsCacheEntry),
|
|
||||||
ttl: ttl,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get retrieves a cached address if it exists and hasn't expired
|
|
||||||
func (c *dnsCache) get(domain string) (netip.Addr, bool) {
|
|
||||||
c.mu.RLock()
|
|
||||||
entry, exists := c.entries[domain]
|
|
||||||
c.mu.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return netip.Addr{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(entry.expiresAt) {
|
|
||||||
// Entry expired, remove it
|
|
||||||
c.mu.Lock()
|
|
||||||
delete(c.entries, domain)
|
|
||||||
c.mu.Unlock()
|
|
||||||
return netip.Addr{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return entry.addr, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// set stores an address in the cache with the configured TTL
|
|
||||||
func (c *dnsCache) set(domain string, addr netip.Addr) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
c.entries[domain] = &dnsCacheEntry{
|
|
||||||
addr: addr,
|
|
||||||
expiresAt: time.Now().Add(c.ttl),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
||||||
type ProxyHandler struct {
|
type ProxyHandler struct {
|
||||||
proxyStack *stack.Stack
|
proxyStack *stack.Stack
|
||||||
@@ -214,7 +161,6 @@ type ProxyHandler struct {
|
|||||||
subnetLookup *SubnetLookup
|
subnetLookup *SubnetLookup
|
||||||
natTable map[connKey]*natState
|
natTable map[connKey]*natState
|
||||||
natMu sync.RWMutex
|
natMu sync.RWMutex
|
||||||
dnsCache *dnsCache
|
|
||||||
enabled bool
|
enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +181,6 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
|||||||
enabled: true,
|
enabled: true,
|
||||||
subnetLookup: NewSubnetLookup(),
|
subnetLookup: NewSubnetLookup(),
|
||||||
natTable: make(map[connKey]*natState),
|
natTable: make(map[connKey]*natState),
|
||||||
dnsCache: newDNSCache(5 * time.Minute), // Cache DNS lookups for 5 minutes
|
|
||||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||||
proxyStack: stack.New(stack.Options{
|
proxyStack: stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
@@ -309,9 +254,8 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
|||||||
// resolveRewriteAddress resolves a rewrite address which can be either:
|
// resolveRewriteAddress resolves a rewrite address which can be either:
|
||||||
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
|
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
|
||||||
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
|
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
|
||||||
// - A domain name (e.g., "example.com") - performs DNS lookup with caching
|
// - A domain name (e.g., "example.com") - performs DNS lookup
|
||||||
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
|
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
|
||||||
|
|
||||||
logger.Debug("Resolving rewrite address: %s", rewriteTo)
|
logger.Debug("Resolving rewrite address: %s", rewriteTo)
|
||||||
|
|
||||||
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
|
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
|
||||||
@@ -324,14 +268,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro
|
|||||||
return addr, nil
|
return addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not an IP address, treat as domain name
|
// Not an IP address, treat as domain name - perform DNS lookup
|
||||||
// Check cache first
|
|
||||||
if cachedAddr, found := p.dnsCache.get(rewriteTo); found {
|
|
||||||
logger.Debug("DNS cache hit for %s: %s", rewriteTo, cachedAddr)
|
|
||||||
return cachedAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache miss, perform DNS lookup
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -348,9 +285,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro
|
|||||||
ip := ips[0]
|
ip := ips[0]
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
|
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
|
||||||
// Cache the result
|
logger.Debug("Resolved %s to %s", rewriteTo, addr)
|
||||||
p.dnsCache.set(rewriteTo, addr)
|
|
||||||
logger.Debug("DNS cache miss for %s, resolved to %s", rewriteTo, addr)
|
|
||||||
return addr, nil
|
return addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,21 +386,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
if matchedRule != nil {
|
if matchedRule != nil {
|
||||||
// Check if we need to perform DNAT
|
// Check if we need to perform DNAT
|
||||||
if matchedRule.RewriteTo != "" {
|
if matchedRule.RewriteTo != "" {
|
||||||
// Resolve the rewrite address (could be IP or domain)
|
// Create connection tracking key using original destination
|
||||||
newDst, err := p.resolveRewriteAddress(matchedRule.RewriteTo)
|
// This allows us to check if we've already resolved for this connection
|
||||||
if err != nil {
|
|
||||||
// Failed to resolve, skip DNAT but still proxy the packet
|
|
||||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
||||||
Payload: buffer.MakeWithData(packet),
|
|
||||||
})
|
|
||||||
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform DNAT - rewrite destination IP
|
|
||||||
originalDst := dstAddr
|
|
||||||
|
|
||||||
// Create connection tracking key
|
|
||||||
var srcPort uint16
|
var srcPort uint16
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case header.TCPProtocolNumber:
|
case header.TCPProtocolNumber:
|
||||||
@@ -476,21 +398,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
srcPort = udpHeader.SourcePort()
|
srcPort = udpHeader.SourcePort()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Key using original destination to track the connection
|
||||||
key := connKey{
|
key := connKey{
|
||||||
srcIP: srcAddr.String(),
|
srcIP: srcAddr.String(),
|
||||||
srcPort: srcPort,
|
srcPort: srcPort,
|
||||||
dstIP: newDst.String(),
|
dstIP: dstAddr.String(),
|
||||||
dstPort: dstPort,
|
dstPort: dstPort,
|
||||||
proto: uint8(protocol),
|
proto: uint8(protocol),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store NAT state for reverse translation
|
// Check if we already have a NAT entry for this connection
|
||||||
p.natMu.Lock()
|
p.natMu.RLock()
|
||||||
p.natTable[key] = &natState{
|
existingEntry, exists := p.natTable[key]
|
||||||
originalDst: originalDst,
|
p.natMu.RUnlock()
|
||||||
rewrittenTo: newDst,
|
|
||||||
|
var newDst netip.Addr
|
||||||
|
if exists {
|
||||||
|
// Use the previously resolved address for this connection
|
||||||
|
newDst = existingEntry.rewrittenTo
|
||||||
|
logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst)
|
||||||
|
} else {
|
||||||
|
// New connection - resolve the rewrite address
|
||||||
|
var err error
|
||||||
|
newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo)
|
||||||
|
if err != nil {
|
||||||
|
// Failed to resolve, skip DNAT but still proxy the packet
|
||||||
|
logger.Debug("Failed to resolve rewrite address: %v", err)
|
||||||
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(packet),
|
||||||
|
})
|
||||||
|
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store NAT state for this connection
|
||||||
|
p.natMu.Lock()
|
||||||
|
p.natTable[key] = &natState{
|
||||||
|
originalDst: dstAddr,
|
||||||
|
rewrittenTo: newDst,
|
||||||
|
}
|
||||||
|
p.natMu.Unlock()
|
||||||
|
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
|
||||||
}
|
}
|
||||||
p.natMu.Unlock()
|
|
||||||
|
|
||||||
// Rewrite the packet
|
// Rewrite the packet
|
||||||
packet = p.rewritePacketDestination(packet, newDst)
|
packet = p.rewritePacketDestination(packet, newDst)
|
||||||
@@ -660,20 +609,23 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up NAT state (key is based on the request, so dst/src are swapped for replies)
|
// Look up NAT state for reverse translation
|
||||||
key := connKey{
|
// The key uses the original dst (before rewrite), so for replies we need to
|
||||||
srcIP: dstIP.String(),
|
// find the entry where the rewritten address matches the current source
|
||||||
srcPort: dstPort,
|
|
||||||
dstIP: srcIP.String(),
|
|
||||||
dstPort: srcPort,
|
|
||||||
proto: uint8(protocol),
|
|
||||||
}
|
|
||||||
|
|
||||||
p.natMu.RLock()
|
p.natMu.RLock()
|
||||||
natEntry, exists := p.natTable[key]
|
var natEntry *natState
|
||||||
|
for k, entry := range p.natTable {
|
||||||
|
// Match: reply's dst should be original src, reply's src should be rewritten dst
|
||||||
|
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
|
||||||
|
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
|
||||||
|
k.proto == uint8(protocol) {
|
||||||
|
natEntry = entry
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
p.natMu.RUnlock()
|
p.natMu.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if natEntry != nil {
|
||||||
// Perform reverse NAT - rewrite source to original destination
|
// Perform reverse NAT - rewrite source to original destination
|
||||||
packet = p.rewritePacketSource(packet, natEntry.originalDst)
|
packet = p.rewritePacketSource(packet, natEntry.originalDst)
|
||||||
if packet != nil {
|
if packet != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user