diff --git a/main.go b/main.go index 9f01344..4270251 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,7 @@ var ( mtuInt int lastReadings = make(map[string]PeerReading) mu sync.Mutex + wgMu sync.Mutex // Protects WireGuard operations notifyURL string proxyServer *relay.UDPProxyServer ) @@ -429,6 +430,9 @@ func assignIPAddress(ipAddress string) error { } func ensureWireguardPeers(peers []Peer) error { + wgMu.Lock() + defer wgMu.Unlock() + // get the current peers device, err := wgClient.Device(interfaceName) if err != nil { @@ -451,8 +455,8 @@ func ensureWireguardPeers(peers []Peer) error { } } if !found { - err := removePeer(peer) - if err != nil { + // Note: We need to call the internal removal logic without re-acquiring the lock + if err := removePeerInternal(peer); err != nil { return fmt.Errorf("failed to remove peer: %v", err) } } @@ -468,8 +472,8 @@ func ensureWireguardPeers(peers []Peer) error { } } if !found { - err := addPeer(configPeer) - if err != nil { + // Note: We need to call the internal addition logic without re-acquiring the lock + if err := addPeerInternal(configPeer); err != nil { return fmt.Errorf("failed to add peer: %v", err) } } @@ -529,7 +533,7 @@ func ensureMSSClamping() error { errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", chain, err, string(out)) logger.Error(errMsg) - errors = append(errors, fmt.Errorf(errMsg)) + errors = append(errors, fmt.Errorf("%s", errMsg)) continue } @@ -546,7 +550,7 @@ func ensureMSSClamping() error { errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", chain, err, string(out)) logger.Error(errMsg) - errors = append(errors, fmt.Errorf(errMsg)) + errors = append(errors, fmt.Errorf("%s", errMsg)) continue } @@ -598,6 +602,12 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) { } func addPeer(peer Peer) error { + wgMu.Lock() + defer wgMu.Unlock() + return addPeerInternal(peer) +} + +func addPeerInternal(peer Peer) error { pubKey, err := wgtypes.ParseKey(peer.PublicKey) if err != nil { return fmt.Errorf("failed to parse public key: %v", err) @@ -662,6 +672,12 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) { } func removePeer(publicKey string) error { + wgMu.Lock() + defer wgMu.Unlock() + return removePeerInternal(publicKey) +} + +func removePeerInternal(publicKey string) error { pubKey, err := wgtypes.ParseKey(publicKey) if err != nil { return fmt.Errorf("failed to parse public key: %v", err) @@ -766,7 +782,10 @@ func periodicBandwidthCheck(endpoint string) { } func calculatePeerBandwidth() ([]PeerBandwidth, error) { + wgMu.Lock() device, err := wgClient.Device(interfaceName) + wgMu.Unlock() + if err != nil { return nil, fmt.Errorf("failed to get device: %v", err) } diff --git a/relay/relay.go b/relay/relay.go index a71eeed..939f4d2 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -440,7 +440,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { - logger.Error("Failed to forward transport data: %v", err) + logger.Debug("Failed to forward transport data: %v", err) } } else { // No known session, fall back to forwarding to all peers @@ -460,7 +460,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { - logger.Error("Failed to forward transport data: %v", err) + logger.Debug("Failed to forward transport data: %v", err) } } } @@ -524,7 +524,7 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd for { n, err := conn.Read(buffer) if err != nil { - logger.Error("Error reading response from %s: %v", destAddr.String(), err) + logger.Debug("Error reading response from %s: %v", destAddr.String(), err) return }