mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-08 05:56:40 +00:00
Add mutex
This commit is contained in:
31
main.go
31
main.go
@@ -30,6 +30,7 @@ var (
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
wgMu sync.Mutex // Protects WireGuard operations
|
||||
notifyURL string
|
||||
proxyServer *relay.UDPProxyServer
|
||||
)
|
||||
@@ -429,6 +430,9 @@ func assignIPAddress(ipAddress string) error {
|
||||
}
|
||||
|
||||
func ensureWireguardPeers(peers []Peer) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
|
||||
// get the current peers
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
@@ -451,8 +455,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := removePeer(peer)
|
||||
if err != nil {
|
||||
// Note: We need to call the internal removal logic without re-acquiring the lock
|
||||
if err := removePeerInternal(peer); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -468,8 +472,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := addPeer(configPeer)
|
||||
if err != nil {
|
||||
// Note: We need to call the internal addition logic without re-acquiring the lock
|
||||
if err := addPeerInternal(configPeer); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -529,7 +533,7 @@ func ensureMSSClamping() error {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf(errMsg))
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -546,7 +550,7 @@ func ensureMSSClamping() error {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf(errMsg))
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -598,6 +602,12 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func addPeer(peer Peer) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
return addPeerInternal(peer)
|
||||
}
|
||||
|
||||
func addPeerInternal(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
@@ -662,6 +672,12 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func removePeer(publicKey string) error {
|
||||
wgMu.Lock()
|
||||
defer wgMu.Unlock()
|
||||
return removePeerInternal(publicKey)
|
||||
}
|
||||
|
||||
func removePeerInternal(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
@@ -766,7 +782,10 @@ func periodicBandwidthCheck(endpoint string) {
|
||||
}
|
||||
|
||||
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
wgMu.Lock()
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
wgMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
@@ -440,7 +440,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward transport data: %v", err)
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
}
|
||||
} else {
|
||||
// No known session, fall back to forwarding to all peers
|
||||
@@ -460,7 +460,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward transport data: %v", err)
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -524,7 +524,7 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
for {
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.Error("Error reading response from %s: %v", destAddr.String(), err)
|
||||
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user