diff --git a/main.go b/main.go index 6352b4b..10950e1 100644 --- a/main.go +++ b/main.go @@ -543,6 +543,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error { logger.Warn("Failed to ensure MSS clamping: %v", err) } + if err := ensureWireguardFirewall(); err != nil { + logger.Warn("Failed to ensure WireGuard firewall rules: %v", err) + } + logger.Info("WireGuard interface %s created and configured", interfaceName) return nil @@ -711,6 +715,113 @@ func ensureMSSClamping() error { return nil } +func ensureWireguardFirewall() error { + // Rules to enforce: + // 1. Allow established/related connections (responses to our outbound traffic) + // 2. Allow ICMP ping packets + // 3. Drop all other inbound traffic from peers + + // Define the rules we want to ensure exist + rules := [][]string{ + // Allow established and related connections (responses to outbound traffic) + { + "-A", "INPUT", + "-i", interfaceName, + "-m", "conntrack", + "--ctstate", "ESTABLISHED,RELATED", + "-j", "ACCEPT", + }, + // Allow ICMP ping requests + { + "-A", "INPUT", + "-i", interfaceName, + "-p", "icmp", + "--icmp-type", "8", + "-j", "ACCEPT", + }, + // Drop all other inbound traffic from WireGuard interface + { + "-A", "INPUT", + "-i", interfaceName, + "-j", "DROP", + }, + } + + // First, try to delete any existing rules for this interface + for _, rule := range rules { + deleteArgs := make([]string, len(rule)) + copy(deleteArgs, rule) + // Change -A to -D for deletion + for i, arg := range deleteArgs { + if arg == "-A" { + deleteArgs[i] = "-D" + break + } + } + + deleteCmd := exec.Command("/usr/sbin/iptables", deleteArgs...) + logger.Debug("Attempting to delete existing firewall rule: %v", deleteArgs) + + // Try deletion multiple times to handle multiple existing rules + for i := 0; i < 5; i++ { + out, err := deleteCmd.CombinedOutput() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + logger.Debug("Deletion stopped: %v (output: %s)", exitErr.String(), string(out)) + } + break // No more rules to delete + } + logger.Info("Deleted existing firewall rule (attempt %d)", i+1) + } + } + + // Now add the rules + var errors []error + for i, rule := range rules { + addCmd := exec.Command("/usr/sbin/iptables", rule...) + logger.Info("Adding WireGuard firewall rule %d: %v", i+1, rule) + + if out, err := addCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Failed to add firewall rule %d: %v (output: %s)", i+1, err, string(out)) + logger.Error("%s", errMsg) + errors = append(errors, fmt.Errorf("%s", errMsg)) + continue + } + + // Verify the rule was added by checking + checkArgs := make([]string, len(rule)) + copy(checkArgs, rule) + // Change -A to -C for check + for j, arg := range checkArgs { + if arg == "-A" { + checkArgs[j] = "-C" + break + } + } + + checkCmd := exec.Command("/usr/sbin/iptables", checkArgs...) + if out, err := checkCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Rule verification failed for rule %d: %v (output: %s)", i+1, err, string(out)) + logger.Error("%s", errMsg) + errors = append(errors, fmt.Errorf("%s", errMsg)) + continue + } + + logger.Info("Successfully added and verified WireGuard firewall rule %d", i+1) + } + + if len(errors) > 0 { + var errMsgs []string + for _, err := range errors { + errMsgs = append(errMsgs, err.Error()) + } + return fmt.Errorf("WireGuard firewall setup encountered errors:\n%s", strings.Join(errMsgs, "\n")) + } + + logger.Info("WireGuard firewall rules successfully configured for interface %s", interfaceName) + return nil +} + func handlePeer(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodPost: diff --git a/relay/relay.go b/relay/relay.go index 59faa4d..2772c67 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -839,7 +839,7 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) { s.wgSessions.Delete(key) } - logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip) + logger.Debug("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip) } // // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP