mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-25 22:36:48 +00:00
Add mutex
This commit is contained in:
31
main.go
31
main.go
@@ -30,6 +30,7 @@ var (
|
|||||||
mtuInt int
|
mtuInt int
|
||||||
lastReadings = make(map[string]PeerReading)
|
lastReadings = make(map[string]PeerReading)
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
wgMu sync.Mutex // Protects WireGuard operations
|
||||||
notifyURL string
|
notifyURL string
|
||||||
proxyServer *relay.UDPProxyServer
|
proxyServer *relay.UDPProxyServer
|
||||||
)
|
)
|
||||||
@@ -429,6 +430,9 @@ func assignIPAddress(ipAddress string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ensureWireguardPeers(peers []Peer) error {
|
func ensureWireguardPeers(peers []Peer) error {
|
||||||
|
wgMu.Lock()
|
||||||
|
defer wgMu.Unlock()
|
||||||
|
|
||||||
// get the current peers
|
// get the current peers
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := wgClient.Device(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -451,8 +455,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
err := removePeer(peer)
|
// Note: We need to call the internal removal logic without re-acquiring the lock
|
||||||
if err != nil {
|
if err := removePeerInternal(peer); err != nil {
|
||||||
return fmt.Errorf("failed to remove peer: %v", err)
|
return fmt.Errorf("failed to remove peer: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -468,8 +472,8 @@ func ensureWireguardPeers(peers []Peer) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
err := addPeer(configPeer)
|
// Note: We need to call the internal addition logic without re-acquiring the lock
|
||||||
if err != nil {
|
if err := addPeerInternal(configPeer); err != nil {
|
||||||
return fmt.Errorf("failed to add peer: %v", err)
|
return fmt.Errorf("failed to add peer: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,7 +533,7 @@ func ensureMSSClamping() error {
|
|||||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||||
chain, err, string(out))
|
chain, err, string(out))
|
||||||
logger.Error(errMsg)
|
logger.Error(errMsg)
|
||||||
errors = append(errors, fmt.Errorf(errMsg))
|
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -546,7 +550,7 @@ func ensureMSSClamping() error {
|
|||||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||||
chain, err, string(out))
|
chain, err, string(out))
|
||||||
logger.Error(errMsg)
|
logger.Error(errMsg)
|
||||||
errors = append(errors, fmt.Errorf(errMsg))
|
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -598,6 +602,12 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func addPeer(peer Peer) error {
|
func addPeer(peer Peer) error {
|
||||||
|
wgMu.Lock()
|
||||||
|
defer wgMu.Unlock()
|
||||||
|
return addPeerInternal(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addPeerInternal(peer Peer) error {
|
||||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
return fmt.Errorf("failed to parse public key: %v", err)
|
||||||
@@ -662,6 +672,12 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func removePeer(publicKey string) error {
|
func removePeer(publicKey string) error {
|
||||||
|
wgMu.Lock()
|
||||||
|
defer wgMu.Unlock()
|
||||||
|
return removePeerInternal(publicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removePeerInternal(publicKey string) error {
|
||||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
return fmt.Errorf("failed to parse public key: %v", err)
|
||||||
@@ -766,7 +782,10 @@ func periodicBandwidthCheck(endpoint string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||||
|
wgMu.Lock()
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := wgClient.Device(interfaceName)
|
||||||
|
wgMu.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -440,7 +440,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
|
|
||||||
_, err = conn.Write(packet)
|
_, err = conn.Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to forward transport data: %v", err)
|
logger.Debug("Failed to forward transport data: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No known session, fall back to forwarding to all peers
|
// No known session, fall back to forwarding to all peers
|
||||||
@@ -460,7 +460,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
|
|
||||||
_, err = conn.Write(packet)
|
_, err = conn.Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to forward transport data: %v", err)
|
logger.Debug("Failed to forward transport data: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -524,7 +524,7 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
|||||||
for {
|
for {
|
||||||
n, err := conn.Read(buffer)
|
n, err := conn.Read(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error reading response from %s: %v", destAddr.String(), err)
|
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user