Files
gerbil/main.go
2024-09-30 09:43:27 -04:00

478 lines
12 KiB
Go

package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"time"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// Process-wide runtime settings. main overrides both defaults with the
// -interface and -listen flags when those flags are non-empty.
var (
// interfaceName is the WireGuard interface every netlink/wgctrl call in this file targets.
interfaceName = "wg0"
// listenAddr is the address the /peer HTTP API is served on.
listenAddr = ":3002"
)
// WgConfig is the startup configuration, decoded from JSON read from the
// -config file or fetched from the -remoteConfig URL.
//
// PrivateKey, ListenPort and IpAddress are applied by ensureWireguardInterface
// only when it has to create the interface; an already-existing interface is
// left as it is.
type WgConfig struct {
// PrivateKey is the interface's private key, parsed with wgtypes.ParseKey.
PrivateKey string `json:"privateKey"`
// ListenPort is set as the device's listen port via wgtypes.Config.ListenPort.
ListenPort int `json:"listenPort"`
// IpAddress is parsed with netlink.ParseAddr, which expects address/prefix
// form (e.g. "10.0.0.1/24"), and added to the interface.
IpAddress string `json:"ipAddress"`
// Peers is the desired peer set: ensureWireguardPeers adds any missing peer
// and removes any device peer not listed here.
Peers []Peer `json:"peers"`
}
// Peer describes one WireGuard peer. It appears in WgConfig.Peers and is also
// the JSON request body accepted by POST /peer.
type Peer struct {
// PublicKey is the peer's public key, parsed with wgtypes.ParseKey.
PublicKey string `json:"publicKey"`
// AllowedIPs are CIDR strings parsed with net.ParseCIDR.
AllowedIPs []string `json:"allowedIps"`
}
// PeerBandwidth is one element of the JSON array that reportPeerBandwidth
// POSTs to the -reportBandwidthTo endpoint.
type PeerBandwidth struct {
// PublicKey is the peer's public key in wgtypes.Key.String() form.
PublicKey string `json:"publicKey"`
// BytesIn is derived from the peer's ReceiveBytes counter. Despite the name,
// calculatePeerBandwidth reports it in MiB (bytes / (1024*1024)) transferred
// during its sampling window.
BytesIn float64 `json:"bytesIn"`
// BytesOut is derived from the peer's TransmitBytes counter, likewise in MiB.
BytesOut float64 `json:"bytesOut"`
}
// wgClient is the single wgctrl handle used by every WireGuard operation in
// this file. main creates it with wgctrl.New before any other WireGuard call.
// NOTE(review): it is used concurrently by the HTTP handlers and the bandwidth
// reporting goroutine — confirm wgctrl.Client is safe for concurrent use.
var (
wgClient *wgctrl.Client
)
// main parses flags and loads the WireGuard configuration. Exactly one of
// --config or --remoteConfig must be set. It then makes sure the interface and
// its peers match that configuration, and optionally starts periodic bandwidth
// reporting. Finally it serves the /peer management API on listenAddr.
func main() {
	var err error
	var wgconfig WgConfig

	// Define command line flags
	interfaceNameArg := flag.String("interface", "wg0", "Name of the WireGuard interface")
	configFile := flag.String("config", "", "Path to local configuration file")
	remoteConfigURL := flag.String("remoteConfig", "", "URL to fetch remote configuration")
	listenAddrArg := flag.String("listen", ":3002", "Address to listen on")
	reportBandwidthTo := flag.String("reportBandwidthTo", "", "URL to periodically POST peer bandwidth reports to (disabled when empty)")
	flag.Parse()

	if *interfaceNameArg != "" {
		interfaceName = *interfaceNameArg
	}
	if *listenAddrArg != "" {
		listenAddr = *listenAddrArg
	}

	// Exactly one configuration source must be given: both set or both empty is an error.
	if (*configFile != "") == (*remoteConfigURL != "") {
		log.Fatal("Please provide either --config or --remoteConfig, but not both")
	}

	wgClient, err = wgctrl.New()
	if err != nil {
		log.Fatalf("Failed to create WireGuard client: %v", err)
	}
	// log.Fatal* exits without running deferred calls, so this Close only runs
	// if main ever returns normally.
	defer wgClient.Close()

	// Load configuration based on provided argument
	if *configFile != "" {
		wgconfig, err = loadConfig(*configFile)
	} else {
		wgconfig, err = loadRemoteConfig(*remoteConfigURL)
	}
	if err != nil {
		log.Fatalf("Failed to load configuration: %v", err)
	}

	// Ensure the WireGuard interface exists and is configured
	if err := ensureWireguardInterface(wgconfig); err != nil {
		log.Fatalf("Failed to ensure WireGuard interface: %v", err)
	}

	// Sync the device's peers with the configuration. Log rather than exit on
	// failure so peers can still be managed through the HTTP API.
	if err := ensureWireguardPeers(wgconfig.Peers); err != nil {
		log.Printf("Failed to ensure WireGuard peers: %v", err)
	}

	if *reportBandwidthTo != "" {
		go periodicBandwidthCheck(*reportBandwidthTo)
	}

	http.HandleFunc("/peer", handlePeer)

	// http.ListenAndServe applies no timeouts at all. An explicit server bounds
	// how long a client may take to send its request headers.
	server := &http.Server{
		Addr:              listenAddr,
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Printf("Starting server on %s", listenAddr)
	log.Fatal(server.ListenAndServe())
}
// loadRemoteConfig fetches a JSON-encoded WgConfig from url with an HTTP GET.
//
// Only a 200 OK response is accepted. Any other status is reported as an
// error rather than trying to decode an error page as configuration. The
// request is bounded by a timeout so startup cannot hang forever on an
// unresponsive server. On any error the zero WgConfig is returned.
func loadRemoteConfig(url string) (WgConfig, error) {
	// One-shot startup fetch, so a local client is fine here.
	client := &http.Client{Timeout: 30 * time.Second}

	resp, err := client.Get(url)
	if err != nil {
		return WgConfig{}, fmt.Errorf("fetching remote config: %w", err)
	}
	defer resp.Body.Close()

	// The URL is deliberately left out of the message; it may carry credentials.
	if resp.StatusCode != http.StatusOK {
		return WgConfig{}, fmt.Errorf("fetching remote config: unexpected status %s", resp.Status)
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return WgConfig{}, fmt.Errorf("reading remote config: %w", err)
	}

	var config WgConfig
	if err := json.Unmarshal(data, &config); err != nil {
		return WgConfig{}, fmt.Errorf("decoding remote config: %w", err)
	}
	return config, nil
}
// loadConfig reads filename and decodes its JSON contents into a WgConfig.
//
// Errors are wrapped with context and returned, not printed. The caller
// decides how to report them, so each failure is reported exactly once. On
// error the zero WgConfig is returned.
func loadConfig(filename string) (WgConfig, error) {
	// os.ReadFile's error already names the file (e.g. "open x.json: no such file").
	data, err := os.ReadFile(filename)
	if err != nil {
		return WgConfig{}, fmt.Errorf("reading config file: %w", err)
	}

	var wgconfig WgConfig
	if err := json.Unmarshal(data, &wgconfig); err != nil {
		return WgConfig{}, fmt.Errorf("parsing config file %s: %w", filename, err)
	}
	return wgconfig, nil
}
// ensureWireguardInterface makes sure the WireGuard interface named by
// interfaceName exists.
//
// If the interface already exists it is left untouched and nil is returned.
// Key, port and address changes in wgconfig therefore take effect only once
// the interface has been recreated. If it is missing, it is created, given
// wgconfig.IpAddress, configured with wgconfig's private key and listen port,
// and brought up.
//
// All failures are returned to the caller instead of terminating the process
// here. If setup fails after the link was created, the link is deleted again.
// Otherwise the next start would find it "already exists" and never configure it.
func ensureWireguardInterface(wgconfig WgConfig) error {
	_, err := netlink.LinkByName(interfaceName)
	if err == nil {
		log.Printf("WireGuard interface %s already exists", interfaceName)
		return nil
	}
	if _, ok := err.(netlink.LinkNotFoundError); !ok {
		return fmt.Errorf("error checking for WireGuard interface: %w", err)
	}

	// Interface doesn't exist, so create it.
	if err := createWireGuardInterface(); err != nil {
		return fmt.Errorf("failed to create WireGuard interface: %w", err)
	}
	log.Printf("Created WireGuard interface %s", interfaceName)

	if err := configureNewWireguardInterface(wgconfig); err != nil {
		removePartialInterface()
		return err
	}

	log.Printf("WireGuard interface %s created and configured", interfaceName)
	return nil
}

// configureNewWireguardInterface assigns the address, private key and listen
// port to a freshly created interface and then brings it up.
func configureNewWireguardInterface(wgconfig WgConfig) error {
	if err := assignIPAddress(wgconfig.IpAddress); err != nil {
		return fmt.Errorf("failed to assign IP address: %w", err)
	}
	log.Printf("Assigned IP address %s to interface %s", wgconfig.IpAddress, interfaceName)

	// Make sure wgctrl can see the new link as a WireGuard device before configuring it.
	if _, err := wgClient.Device(interfaceName); err != nil {
		return fmt.Errorf("interface %s is not available as a WireGuard device: %w", interfaceName, err)
	}

	key, err := wgtypes.ParseKey(wgconfig.PrivateKey)
	if err != nil {
		return fmt.Errorf("failed to parse private key: %w", err)
	}

	listenPort := wgconfig.ListenPort
	config := wgtypes.Config{
		PrivateKey: &key,
		ListenPort: &listenPort,
	}
	if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
		return fmt.Errorf("failed to configure WireGuard device: %w", err)
	}

	link, err := netlink.LinkByName(interfaceName)
	if err != nil {
		return fmt.Errorf("failed to get interface: %w", err)
	}
	if err := netlink.LinkSetUp(link); err != nil {
		return fmt.Errorf("failed to bring up interface: %w", err)
	}
	return nil
}

// removePartialInterface deletes interfaceName after a failed setup. It is
// best effort: failures are logged because the original setup error is what
// gets reported.
func removePartialInterface() {
	link, err := netlink.LinkByName(interfaceName)
	if err != nil {
		log.Printf("Failed to look up partially configured interface %s for removal: %v", interfaceName, err)
		return
	}
	if err := netlink.LinkDel(link); err != nil {
		log.Printf("Failed to remove partially configured interface %s: %v", interfaceName, err)
	}
}
// createWireGuardInterface adds a new link of type "wireguard" named
// interfaceName through netlink.LinkAdd. It only creates the link. Addressing,
// keys and bringing the link up are not done here.
func createWireGuardInterface() error {
	var attrs netlink.LinkAttrs
	attrs.Name = interfaceName

	link := netlink.GenericLink{
		LinkAttrs: attrs,
		LinkType:  "wireguard",
	}
	return netlink.LinkAdd(&link)
}
// assignIPAddress adds ipAddress to the interface named by interfaceName.
// ipAddress must be in the address/prefix form accepted by netlink.ParseAddr
// (e.g. "10.0.0.1/24").
//
// Underlying errors are wrapped with %w, so callers can inspect them with
// errors.Is/errors.As.
func assignIPAddress(ipAddress string) error {
	link, err := netlink.LinkByName(interfaceName)
	if err != nil {
		return fmt.Errorf("failed to get interface: %w", err)
	}

	addr, err := netlink.ParseAddr(ipAddress)
	if err != nil {
		return fmt.Errorf("failed to parse IP address %q: %w", ipAddress, err)
	}

	if err := netlink.AddrAdd(link, addr); err != nil {
		return fmt.Errorf("failed to add address %s to %s: %w", ipAddress, interfaceName, err)
	}
	return nil
}
// ensureWireguardPeers makes the device's peer set match peers.
//
// It works in two passes over one snapshot of the device. First, every device
// peer whose public key is not in peers is removed. Second, every entry of
// peers whose key the device did not have is added. Peers present in both are
// left as they are, so their AllowedIPs are not updated here. Keys are
// compared as strings: the configured PublicKey against wgtypes.Key.String().
// It stops at the first failure and returns it.
func ensureWireguardPeers(peers []Peer) error {
	device, err := wgClient.Device(interfaceName)
	if err != nil {
		return fmt.Errorf("failed to get device: %w", err)
	}

	// Sets give O(1) membership checks instead of nested scans over both lists.
	current := make(map[string]struct{}, len(device.Peers))
	for _, p := range device.Peers {
		current[p.PublicKey.String()] = struct{}{}
	}
	desired := make(map[string]struct{}, len(peers))
	for _, p := range peers {
		desired[p.PublicKey] = struct{}{}
	}

	// Remove any peers that are not in the config.
	for _, p := range device.Peers {
		key := p.PublicKey.String()
		if _, keep := desired[key]; keep {
			continue
		}
		if err := removePeer(key); err != nil {
			return fmt.Errorf("failed to remove peer: %w", err)
		}
	}

	// Add any peers that are in the config but were not on the device.
	for _, p := range peers {
		if _, exists := current[p.PublicKey]; exists {
			continue
		}
		if err := addPeer(p); err != nil {
			return fmt.Errorf("failed to add peer: %w", err)
		}
	}
	return nil
}
// handlePeer routes /peer requests: POST adds a peer (handleAddPeer) and
// DELETE removes one (handleRemovePeer). Any other method gets 405 Method Not
// Allowed.
func handlePeer(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodPost:
		handleAddPeer(w, r)
	case http.MethodDelete:
		handleRemovePeer(w, r)
	default:
		// A 405 response must carry an Allow header listing the supported
		// methods (RFC 9110 §15.5.6). It must be set before http.Error writes
		// the status.
		w.Header().Set("Allow", http.MethodPost+", "+http.MethodDelete)
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleAddPeer serves POST /peer. The request body is a JSON-encoded Peer.
// On success the peer is added via addPeer and 201 Created is returned with a
// JSON status object.
//
// Responses:
//   - 400 for a malformed or oversized body, or a missing publicKey;
//   - 500 when addPeer fails (including an unparsable key or CIDR).
func handleAddPeer(w http.ResponseWriter, r *http.Request) {
	// Bound the body so a client cannot make the decoder buffer unlimited input.
	r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

	var peer Peer
	if err := json.NewDecoder(r.Body).Decode(&peer); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if peer.PublicKey == "" {
		http.Error(w, "Missing publicKey in request body", http.StatusBadRequest)
		return
	}

	if err := addPeer(peer); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Headers only take effect if set before WriteHeader.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	if err := json.NewEncoder(w).Encode(map[string]string{"status": "Peer added successfully"}); err != nil {
		// The status line is already sent; the failure can only be recorded.
		log.Printf("Failed to write add-peer response: %v", err)
	}
}
// addPeer configures peer on the interface named by interfaceName.
//
// The public key is parsed with wgtypes.ParseKey. Each AllowedIPs entry is
// parsed with net.ParseCIDR, whose returned network has host bits masked off,
// so "10.0.0.5/24" becomes 10.0.0.0/24. ReplaceAllowedIPs is not set. If the
// peer already exists, the allowed IPs given here are therefore added to its
// existing set rather than replacing it.
// Errors are wrapped with %w; an invalid CIDR is named in the error.
func addPeer(peer Peer) error {
	pubKey, err := wgtypes.ParseKey(peer.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to parse public key: %w", err)
	}

	// Left nil when there are no allowed IPs, matching what is passed to wgctrl today.
	var allowedIPs []net.IPNet
	for _, ipStr := range peer.AllowedIPs {
		_, ipNet, err := net.ParseCIDR(ipStr)
		if err != nil {
			return fmt.Errorf("failed to parse allowed IP %q: %w", ipStr, err)
		}
		allowedIPs = append(allowedIPs, *ipNet)
	}

	config := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey:  pubKey,
			AllowedIPs: allowedIPs,
		}},
	}
	if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
		return fmt.Errorf("failed to add peer: %w", err)
	}

	log.Printf("Peer %s added successfully", peer.PublicKey)
	return nil
}
// handleRemovePeer serves DELETE /peer?public_key=<key>. It removes the peer
// via removePeer and returns 200 OK with a JSON status object.
//
// Responses:
//   - 400 when the public_key query parameter is missing;
//   - 500 when removePeer fails (including an unparsable key).
func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
	publicKey := r.URL.Query().Get("public_key")
	if publicKey == "" {
		http.Error(w, "Missing public_key query parameter", http.StatusBadRequest)
		return
	}

	if err := removePeer(publicKey); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Headers only take effect if set before WriteHeader.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(map[string]string{"status": "Peer removed successfully"}); err != nil {
		// The status line is already sent; the failure can only be recorded.
		log.Printf("Failed to write remove-peer response: %v", err)
	}
}
// removePeer removes the peer identified by publicKey from the interface
// named by interfaceName. It sends a wgtypes.PeerConfig with Remove set.
// Errors are wrapped with %w so callers can inspect the underlying cause.
func removePeer(publicKey string) error {
	pubKey, err := wgtypes.ParseKey(publicKey)
	if err != nil {
		return fmt.Errorf("failed to parse public key: %w", err)
	}

	config := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey: pubKey,
			Remove:    true,
		}},
	}
	if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
		return fmt.Errorf("failed to remove peer: %w", err)
	}

	log.Printf("Peer %s removed successfully", publicKey)
	return nil
}
// periodicBandwidthCheck calls reportPeerBandwidth(endpoint) on every tick of
// a 10-second ticker, forever. A failed report is logged and the loop simply
// waits for the next tick; the function never returns.
func periodicBandwidthCheck(endpoint string) {
	const reportInterval = 10 * time.Second

	tick := time.NewTicker(reportInterval)
	defer tick.Stop()

	for {
		<-tick.C
		err := reportPeerBandwidth(endpoint)
		if err == nil {
			continue
		}
		log.Printf("Failed to report peer bandwidth: %v", err)
	}
}
// calculatePeerBandwidth measures how much traffic each peer exchanged during
// one 5-second window. It works as follows:
//   - read the device's per-peer counters;
//   - sleep once;
//   - read the counters again;
//   - return one PeerBandwidth per peer present in both readings, in the order
//     of the first reading.
//
// BytesIn (from ReceiveBytes) and BytesOut (from TransmitBytes) are in MiB
// despite their names. The result is non-nil, so it marshals to [] rather than
// null when there are no peers.
//
// A single window is shared by all peers, so one sweep takes about 5 s
// regardless of how many peers there are.
//
// TODO: fix this to actually only report the change in bandwidth from the last query
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
	const sampleWindow = 5 * time.Second
	const bytesPerMiB = 1024 * 1024

	before, err := wgClient.Device(interfaceName)
	if err != nil {
		return nil, fmt.Errorf("failed to get device: %w", err)
	}

	time.Sleep(sampleWindow)

	after, err := wgClient.Device(interfaceName)
	if err != nil {
		return nil, fmt.Errorf("failed to get updated device: %w", err)
	}

	latest := make(map[wgtypes.Key]wgtypes.Peer, len(after.Peers))
	for _, p := range after.Peers {
		latest[p.PublicKey] = p
	}

	peerBandwidths := make([]PeerBandwidth, 0, len(before.Peers))
	for _, peer := range before.Peers {
		updated, ok := latest[peer.PublicKey]
		if !ok {
			// The peer disappeared during the window; nothing to report.
			continue
		}

		// A counter that went backwards was reset (presumably the peer was
		// removed and re-added). In that case count only the traffic since the
		// reset instead of reporting a negative delta.
		received := updated.ReceiveBytes
		if received >= peer.ReceiveBytes {
			received -= peer.ReceiveBytes
		}
		sent := updated.TransmitBytes
		if sent >= peer.TransmitBytes {
			sent -= peer.TransmitBytes
		}

		peerBandwidths = append(peerBandwidths, PeerBandwidth{
			PublicKey: peer.PublicKey.String(),
			BytesIn:   float64(received) / bytesPerMiB,
			BytesOut:  float64(sent) / bytesPerMiB,
		})
	}
	return peerBandwidths, nil
}
// bandwidthReportClient is shared by all bandwidth reports so keep-alive
// connections are reused. Its timeout keeps one unresponsive collector from
// blocking the reporting loop forever; http.DefaultClient has no timeout.
var bandwidthReportClient = &http.Client{Timeout: 10 * time.Second}

// reportPeerBandwidth computes current per-peer bandwidth with
// calculatePeerBandwidth and POSTs it as a JSON array to apiURL. Any status
// other than 200 OK is treated as a failure.
func reportPeerBandwidth(apiURL string) error {
	bandwidths, err := calculatePeerBandwidth()
	if err != nil {
		return fmt.Errorf("failed to calculate peer bandwidth: %w", err)
	}

	jsonData, err := json.Marshal(bandwidths)
	if err != nil {
		return fmt.Errorf("failed to marshal bandwidth data: %w", err)
	}

	resp, err := bandwidthReportClient.Post(apiURL, "application/json", bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("failed to send bandwidth data: %w", err)
	}
	defer resp.Body.Close()

	// Drain (a bounded amount of) the body so the transport can reuse the
	// connection. A drain failure does not change the outcome of the report,
	// so its error is deliberately ignored.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("API returned non-OK status: %s", resp.Status)
	}
	return nil
}