diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..c0364e4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.gitignore +.dockerignore +gerbil +data.json +*.json +docker-compose.yml +README.md +Makefile \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c015594 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +go.sum +gerbil diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c4f40f0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +FROM golang:1.21.5-alpine AS build + +# Set the working directory inside the container +WORKDIR /app + +# Copy go mod and sum files +COPY go.mod go.sum ./ + +# Download all dependencies +RUN go mod download + +# Copy the source code into the container +COPY . . + +# Build the application +RUN CGO_ENABLED=0 GOOS=linux go build -o /gerbil + +# Start a new stage from scratch +FROM ubuntu:22:04 + +RUN RUN apt-get update && apt-get install -y nftables && apt-get clean + +WORKDIR /root/ + +# Copy the pre-built binary file from the previous stage +COPY --from=build /gerbil . + +# Command to run the executable +CMD ["./gerbil"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2b790d2 --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +all: + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o gerbil + +clean: + rm gerbil \ No newline at end of file diff --git a/badger b/badger new file mode 100755 index 0000000..c05f876 Binary files /dev/null and b/badger differ diff --git a/config_exmaple.json b/config_exmaple.json new file mode 100644 index 0000000..4b20271 --- /dev/null +++ b/config_exmaple.json @@ -0,0 +1,23 @@ +{ + "privateKey": "kBGTgk7c+zncEEoSnMl+jsLjVh5ZVoL/HwBSQem+d1M=", + "listenPort": 51820, + "ipAddress": "10.0.0.1/24", + "peers": [ + { + "publicKey": "5UzzoeveFVSzuqK3nTMS5bA1jIMs1fQffVQzJ8MXUQM=", + "allowedIps": ["10.0.0.0/28"] + }, + { + "publicKey": "kYrZpuO2NsrFoBh1GMNgkhd1i9Rgtu1rAjbJ7qsfngU=", + "allowedIps": ["10.0.0.16/28"] + }, + { + "publicKey": "1YfPUVr9ZF4zehkbI2BQhCxaRLz+Vtwa4vJwH+mpK0A=", + "allowedIps": ["10.0.0.32/28"] + }, + { + "publicKey": "2/U4oyZ+sai336Dal/yExCphL8AxyqvIxMk4qsUy4iI=", + "allowedIps": ["10.0.0.48/28"] + } + ] +} \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..7f9305a --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,16 @@ +version: '3' + +services: + gerbil: + image: gerbil + container_name: gerbil + cap_add: + - NET_ADMIN + - SYS_MODULE + volumes: + - ./config:/config + ports: + - 51820:51820/udp + sysctls: + - net.ipv4.conf.all.src_valid_mark=1 + restart: unless-stopped diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..26bbe23 --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module github.com/fosrl/gerbil + +go 1.21.5 + +require ( + github.com/google/go-cmp v0.5.9 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.4.1 // indirect + github.com/vishvananda/netlink v1.3.0 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + golang.org/x/crypto v0.8.0 // indirect + golang.org/x/net v0.9.0 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect +) diff --git a/main.go b/main.go new file mode 100644 index 0000000..0a96874 --- /dev/null +++ b/main.go @@ -0,0 +1,477 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "time" + + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var ( + interfaceName = "wg0" + listenAddr = ":8080" +) + +type WgConfig struct { + PrivateKey string `json:"privateKey"` + ListenPort int `json:"listenPort"` + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +var ( + wgClient *wgctrl.Client +) + +func main() { + var err error + var wgconfig WgConfig + + // Define command line flags + interfaceNameArg := flag.String("interface", "wg0", "Name of the WireGuard interface") + configFile := flag.String("config", "", "Path to local configuration file") + remoteConfigURL := flag.String("remoteConfig", "", "URL to fetch remote configuration") + listenAddrArg := flag.String("listen", ":8080", "Address to listen on") + reportBandwidthTo := flag.String("reportBandwidthTo", "", "Address to listen on") + flag.Parse() + + if *interfaceNameArg != "" { + interfaceName = *interfaceNameArg + } + if *listenAddrArg != "" { + listenAddr = *listenAddrArg + } + + // Validate that only one config option is provided + if (*configFile != "" && *remoteConfigURL != "") || (*configFile == "" && *remoteConfigURL == "") { + log.Fatal("Please provide either --config or --remoteConfig, but not both") + } + + wgClient, err = wgctrl.New() + if err != nil { + log.Fatalf("Failed to create WireGuard client: %v", err) + } + defer wgClient.Close() + + // Load configuration based on provided argument + if *configFile != "" { + wgconfig, err = loadConfig(*configFile) + } else { + wgconfig, err = loadRemoteConfig(*remoteConfigURL) + } + + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Ensure the WireGuard interface exists and is configured + if err := ensureWireguardInterface(wgconfig); err != nil { + log.Fatalf("Failed to ensure WireGuard interface: %v", err) + } + + // Ensure the WireGuard peers exist + ensureWireguardPeers(wgconfig.Peers) + + if *reportBandwidthTo != "" { + go periodicBandwidthCheck(*reportBandwidthTo) + } + + http.HandleFunc("/peer", handlePeer) + log.Printf("Starting server on %s", listenAddr) + log.Fatal(http.ListenAndServe(listenAddr, nil)) +} + +func loadRemoteConfig(url string) (WgConfig, error) { + resp, err := http.Get(url) + if err != nil { + return WgConfig{}, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return WgConfig{}, err + } + + var config WgConfig + err = json.Unmarshal(data, &config) + return config, err +} + +func loadConfig(filename string) (WgConfig, error) { + // Open the JSON file + file, err := os.Open(filename) + if err != nil { + fmt.Println("Error opening file:", err) + return WgConfig{}, err + } + defer file.Close() + + // Read the file contents + byteValue, err := io.ReadAll(file) + if err != nil { + fmt.Println("Error reading file:", err) + return WgConfig{}, err + } + + // Create a variable of the appropriate type to hold the unmarshaled data + var wgconfig WgConfig + + // Unmarshal the JSON data into the struct + err = json.Unmarshal(byteValue, &wgconfig) + if err != nil { + fmt.Println("Error unmarshaling JSON:", err) + return WgConfig{}, err + } + + return wgconfig, nil +} + +func ensureWireguardInterface(wgconfig WgConfig) error { + // Check if the WireGuard interface exists + _, err := netlink.LinkByName(interfaceName) + if err != nil { + if _, ok := err.(netlink.LinkNotFoundError); ok { + // Interface doesn't exist, so create it + err = createWireGuardInterface() + if err != nil { + log.Fatalf("Failed to create WireGuard interface: %v", err) + } + log.Printf("Created WireGuard interface %s\n", interfaceName) + } else { + log.Fatalf("Error checking for WireGuard interface: %v", err) + } + } else { + log.Printf("WireGuard interface %s already exists\n", interfaceName) + return nil + } + + // Assign IP address to the interface + err = assignIPAddress(wgconfig.IpAddress) + if err != nil { + log.Fatalf("Failed to assign IP address: %v", err) + } + log.Printf("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) + + // Check if the interface already exists + _, err = wgClient.Device(interfaceName) + if err != nil { + return fmt.Errorf("interface %s does not exist", interfaceName) + } + + // Parse the private key + key, err := wgtypes.ParseKey(wgconfig.PrivateKey) + if err != nil { + return fmt.Errorf("failed to parse private key: %v", err) + } + + // Create a new WireGuard configuration + config := wgtypes.Config{ + PrivateKey: &key, + ListenPort: new(int), + } + *config.ListenPort = wgconfig.ListenPort + + // Create and configure the WireGuard interface + err = wgClient.ConfigureDevice(interfaceName, config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // bring up the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + log.Printf("WireGuard interface %s created and configured", interfaceName) + + return nil +} + +func createWireGuardInterface() error { + wgLink := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, + LinkType: "wireguard", + } + return netlink.LinkAdd(wgLink) +} + +func assignIPAddress(ipAddress string) error { + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + addr, err := netlink.ParseAddr(ipAddress) + if err != nil { + return fmt.Errorf("failed to parse IP address: %v", err) + } + + return netlink.AddrAdd(link, addr) +} + +func ensureWireguardPeers(peers []Peer) error { + // get the current peers + device, err := wgClient.Device(interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the peer public keys + var currentPeers []string + for _, peer := range device.Peers { + currentPeers = append(currentPeers, peer.PublicKey.String()) + } + + // remove any peers that are not in the config + for _, peer := range currentPeers { + found := false + for _, configPeer := range peers { + if peer == configPeer.PublicKey { + found = true + break + } + } + if !found { + err := removePeer(peer) + if err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + } + } + + // add any peers that are in the config but not in the current peers + for _, configPeer := range peers { + found := false + for _, peer := range currentPeers { + if configPeer.PublicKey == peer { + found = true + break + } + } + if !found { + err := addPeer(configPeer) + if err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + } + + return nil +} + +func handlePeer(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + handleAddPeer(w, r) + case http.MethodDelete: + handleRemovePeer(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +func handleAddPeer(w http.ResponseWriter, r *http.Request) { + var peer Peer + if err := json.NewDecoder(r.Body).Decode(&peer); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + err := addPeer(peer) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{"status": "Peer added successfully"}) +} + +func addPeer(peer Peer) error { + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // parse allowed IPs into array of net.IPNet + var allowedIPs []net.IPNet + for _, ipStr := range peer.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + return fmt.Errorf("failed to parse allowed IP: %v", err) + } + allowedIPs = append(allowedIPs, *ipNet) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + + log.Printf("Peer %s added successfully", peer.PublicKey) + + return nil +} + +func handleRemovePeer(w http.ResponseWriter, r *http.Request) { + publicKey := r.URL.Query().Get("public_key") + if publicKey == "" { + http.Error(w, "Missing public_key query parameter", http.StatusBadRequest) + return + } + + err := removePeer(publicKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "Peer removed successfully"}) +} + +func removePeer(publicKey string) error { + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + log.Printf("Peer %s removed successfully", publicKey) + + return nil +} + +func periodicBandwidthCheck(endpoint string) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := reportPeerBandwidth(endpoint); err != nil { + log.Printf("Failed to report peer bandwidth: %v", err) + } + } +} + +func calculatePeerBandwidth() ([]PeerBandwidth, error) { + device, err := wgClient.Device(interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get device: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + + for _, peer := range device.Peers { + // Store initial values + initialBytesReceived := peer.ReceiveBytes + initialBytesSent := peer.TransmitBytes + + // Wait for a short period to measure change + time.Sleep(5 * time.Second) + + // Get updated device info + updatedDevice, err := wgClient.Device(interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get updated device: %v", err) + } + + var updatedPeer *wgtypes.Peer + for _, p := range updatedDevice.Peers { + if p.PublicKey == peer.PublicKey { + updatedPeer = &p + break + } + } + + if updatedPeer == nil { + continue + } + + // Calculate change in bytes + bytesInDiff := float64(updatedPeer.ReceiveBytes - initialBytesReceived) + bytesOutDiff := float64(updatedPeer.TransmitBytes - initialBytesSent) + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: peer.PublicKey.String(), + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + }) + } + + return peerBandwidths, nil +} + +func reportPeerBandwidth(apiURL string) error { + bandwidths, err := calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + jsonData, err := json.Marshal(bandwidths) + if err != nil { + return fmt.Errorf("failed to marshal bandwidth data: %v", err) + } + + resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("API returned non-OK status: %s", resp.Status) + } + + // log.Println("Bandwidth data sent successfully") + return nil +}