Add mutex

This commit is contained in:
Owen
2025-07-28 21:35:57 -07:00
parent fc7df8a530
commit 78c768e497
2 changed files with 28 additions and 9 deletions

31
main.go
View File

@@ -30,6 +30,7 @@ var (
mtuInt int
lastReadings = make(map[string]PeerReading)
mu sync.Mutex
wgMu sync.Mutex // Protects WireGuard operations
notifyURL string
proxyServer *relay.UDPProxyServer
)
@@ -429,6 +430,9 @@ func assignIPAddress(ipAddress string) error {
}
func ensureWireguardPeers(peers []Peer) error {
wgMu.Lock()
defer wgMu.Unlock()
// get the current peers
device, err := wgClient.Device(interfaceName)
if err != nil {
@@ -451,8 +455,8 @@ func ensureWireguardPeers(peers []Peer) error {
}
}
if !found {
err := removePeer(peer)
if err != nil {
// Note: We need to call the internal removal logic without re-acquiring the lock
if err := removePeerInternal(peer); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
}
@@ -468,8 +472,8 @@ func ensureWireguardPeers(peers []Peer) error {
}
}
if !found {
err := addPeer(configPeer)
if err != nil {
// Note: We need to call the internal addition logic without re-acquiring the lock
if err := addPeerInternal(configPeer); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
}
@@ -529,7 +533,7 @@ func ensureMSSClamping() error {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
errors = append(errors, fmt.Errorf(errMsg))
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -546,7 +550,7 @@ func ensureMSSClamping() error {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
errors = append(errors, fmt.Errorf(errMsg))
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -598,6 +602,12 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
}
func addPeer(peer Peer) error {
wgMu.Lock()
defer wgMu.Unlock()
return addPeerInternal(peer)
}
func addPeerInternal(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
@@ -662,6 +672,12 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
}
func removePeer(publicKey string) error {
wgMu.Lock()
defer wgMu.Unlock()
return removePeerInternal(publicKey)
}
func removePeerInternal(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
@@ -766,7 +782,10 @@ func periodicBandwidthCheck(endpoint string) {
}
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
wgMu.Lock()
device, err := wgClient.Device(interfaceName)
wgMu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}

View File

@@ -440,7 +440,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward transport data: %v", err)
logger.Debug("Failed to forward transport data: %v", err)
}
} else {
// No known session, fall back to forwarding to all peers
@@ -460,7 +460,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward transport data: %v", err)
logger.Debug("Failed to forward transport data: %v", err)
}
}
}
@@ -524,7 +524,7 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
for {
n, err := conn.Read(buffer)
if err != nil {
logger.Error("Error reading response from %s: %v", destAddr.String(), err)
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
return
}