diff --git a/main.go b/main.go index 786ecbd..9cee671 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" + "github.com/fosrl/newt/wg" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" @@ -246,15 +247,18 @@ func resolveDomain(domain string) (string, error) { func main() { var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + interfaceName string + generateAndSaveKeyTo string + reachableAt string ) // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values @@ -264,6 +268,9 @@ func main() { mtu = os.Getenv("MTU") dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") + interfaceName = os.Getenv("INTERFACE") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + reachableAt = os.Getenv("REACHABLE_AT") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -283,6 +290,15 @@ func main() { if logLevel == "" { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") + } + if reachableAt == "" { + flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -319,6 +335,21 @@ func main() { logger.Fatal("Failed to create client: %v", err) } + // Create WireGuard service + wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + // defer wgService.Close() + + // Start the WireGuard service + if err := wgService.Start(); err != nil { + logger.Fatal("Failed to start WireGuard service: %v", err) + } + + // Start bandwidth reporting + wgService.StartBandwidthReporting() + // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net diff --git a/wg/wg.go b/wg/wg.go index ed868ea..dc5e337 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -3,14 +3,10 @@ package wg import ( "bytes" "encoding/json" - "flag" "fmt" - "io" "net" - "net/http" "os" "os/exec" - "strconv" "strings" "sync" "time" @@ -58,147 +54,149 @@ var ( wgClient *wgctrl.Client ) -func main() { - var ( - err error - wgconfig WgConfig - remoteConfigURL string - generateAndSaveKeyTo string - reachableAt string - logLevel string - mtu string - ) +type WireGuardService struct { + interfaceName string + mtu int + client *websocket.Client + wgClient *wgctrl.Client + config WgConfig + key wgtypes.Key + reachableAt string + generateAndSaveKeyTo string + lastReadings map[string]PeerReading + mu sync.Mutex +} - interfaceName = os.Getenv("INTERFACE") - remoteConfigURL = os.Getenv("REMOTE_CONFIG") - listenAddr = os.Getenv("LISTEN") - generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") - reachableAt = os.Getenv("REACHABLE_AT") - logLevel = os.Getenv("LOG_LEVEL") - mtu = os.Getenv("MTU") - - if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") - } - if remoteConfigURL == "" { - flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration") - } - if generateAndSaveKeyTo == "" { - flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") - } - if reachableAt == "" { - flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") - } - if logLevel == "" { - flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") - } - if mtu == "" { - flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface") - } - flag.Parse() - - mtuInt, err = strconv.Atoi(mtu) +func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { + wgClient, err := wgctrl.New() if err != nil { - logger.Fatal("Failed to parse MTU: %v", err) + return nil, fmt.Errorf("failed to create WireGuard client: %v", err) } - var key wgtypes.Key + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } + + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + wgClient: wgClient, + key: key, + reachableAt: reachableAt, + generateAndSaveKeyTo: generateAndSaveKeyTo, + lastReadings: make(map[string]PeerReading), + } + + // Register websocket handlers + wsClient.RegisterHandler("wg/peer/config", service.handleConfig) + wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer) + + // Register connect handler to initiate configuration + wsClient.OnConnect(service.handleConnect) + + return service, nil +} + +func (s *WireGuardService) handleConnect() error { + logger.Debug("Public key: %s", s.key.PublicKey()) + + err := s.client.SendMessage("wg/register", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey()), + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Sent registration message") + return nil +} + +func (s *WireGuardService) Start() error { + // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - if generateAndSaveKeyTo != "" { - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - // generate a new private key - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - // save the key to the file - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) - if err != nil { - logger.Fatal("Failed to save private key: %v", err) - } - } else { - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - logger.Fatal("Failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(string(keyData)) - if err != nil { - logger.Fatal("Failed to parse private key: %v", err) - } - } - } else { - // if no generateAndSaveKeyTo is provided, ensure that the private key is provided - if wgconfig.PrivateKey == "" { - // generate a new one - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - } - } - - // loop until we get the config - for wgconfig.PrivateKey == "" { - logger.Info("Fetching remote config from %s", remoteConfigURL) - wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt) + if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) { + // generate a new private key + s.key, err = wgtypes.GeneratePrivateKey() if err != nil { - logger.Error("Failed to load configuration: %v", err) - time.Sleep(5 * time.Second) - continue + logger.Fatal("Failed to generate private key: %v", err) + } + // save the key to the file + err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644) + if err != nil { + logger.Fatal("Failed to save private key: %v", err) } - wgconfig.PrivateKey = key.String() - } - - wgClient, err = wgctrl.New() - if err != nil { - logger.Fatal("Failed to create WireGuard client: %v", err) - } - defer wgClient.Close() - - // Ensure the WireGuard interface exists and is configured - if err := ensureWireguardInterface(wgconfig); err != nil { - logger.Fatal("Failed to ensure WireGuard interface: %v", err) - } - - // Ensure the WireGuard peers exist - ensureWireguardPeers(wgconfig.Peers) - - // go periodicBandwidthCheck(reportBandwidthTo) -} - -func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { - var body *bytes.Buffer - if reachableAt == "" { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String()))) } else { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) + keyData, err := os.ReadFile(s.generateAndSaveKeyTo) + if err != nil { + logger.Fatal("Failed to read private key: %v", err) + } + s.key, err = wgtypes.ParseKey(string(keyData)) + if err != nil { + logger.Fatal("Failed to parse private key: %v", err) + } } - resp, err := http.Post(url, "application/json", body) + + // Get initial configuration + err := s.loadRemoteConfig() if err != nil { - // print the error - logger.Error("Error fetching remote config %s: %v", url, err) - return WgConfig{}, err - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - return WgConfig{}, err + return fmt.Errorf("failed to load initial configuration: %v", err) } - var config WgConfig - err = json.Unmarshal(data, &config) - - return config, err + return nil } -func ensureWireguardInterface(wgconfig WgConfig) error { +func (s *WireGuardService) StartBandwidthReporting() { + go s.periodicBandwidthCheck() +} + +func (s *WireGuardService) loadRemoteConfig() error { + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) + + // send a ws message to the server to get the config + + err := s.client.SendMessage("wg/config/get", body) + if err != nil { + return fmt.Errorf("failed to send config request: %v", err) + } + + return nil +} + +func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { + var config WgConfig + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + } + + s.config = config + + // Ensure the WireGuard interface and peers are configured + if err := s.ensureWireguardInterface(config); err != nil { + logger.Error("Failed to ensure WireGuard interface: %v", err) + } + + if err := s.ensureWireguardPeers(config.Peers); err != nil { + logger.Error("Failed to ensure WireGuard peers: %v", err) + } +} + +func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Check if the WireGuard interface exists _, err := netlink.LinkByName(interfaceName) if err != nil { if _, ok := err.(netlink.LinkNotFoundError); ok { // Interface doesn't exist, so create it - err = createWireGuardInterface() + err = s.createWireGuardInterface() if err != nil { logger.Fatal("Failed to create WireGuard interface: %v", err) } @@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { } // Assign IP address to the interface - err = assignIPAddress(wgconfig.IpAddress) + err = s.assignIPAddress(wgconfig.IpAddress) if err != nil { logger.Fatal("Failed to assign IP address: %v", err) } @@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { return fmt.Errorf("failed to bring up interface: %v", err) } - if err := ensureMSSClamping(); err != nil { + if err := s.ensureMSSClamping(); err != nil { logger.Warn("Failed to ensure MSS clamping: %v", err) } @@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { return nil } -func createWireGuardInterface() error { +func (s *WireGuardService) createWireGuardInterface() error { wgLink := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, LinkType: "wireguard", @@ -274,7 +272,7 @@ func createWireGuardInterface() error { return netlink.LinkAdd(wgLink) } -func assignIPAddress(ipAddress string) error { +func (s *WireGuardService) assignIPAddress(ipAddress string) error { link, err := netlink.LinkByName(interfaceName) if err != nil { return fmt.Errorf("failed to get interface: %v", err) @@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error { return netlink.AddrAdd(link, addr) } -func ensureWireguardPeers(peers []Peer) error { +func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { // get the current peers device, err := wgClient.Device(interfaceName) if err != nil { @@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error { } } if !found { - err := removePeer(peer) + err := s.removePeer(peer) if err != nil { return fmt.Errorf("failed to remove peer: %v", err) } @@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error { } } if !found { - err := addPeer(configPeer) + err := s.addPeer(configPeer) if err != nil { return fmt.Errorf("failed to add peer: %v", err) } @@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error { return nil } -func ensureMSSClamping() error { +func (s *WireGuardService) ensureMSSClamping() error { // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) mssValue := mtuInt - 40 @@ -426,7 +424,7 @@ func ensureMSSClamping() error { return nil } -func handleAddPeer(msg websocket.WSMessage) { +func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { var peer Peer jsonData, err := json.Marshal(msg.Data) @@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) { logger.Info("Error unmarshaling target data: %v", err) } - err = addPeer(peer) + err = s.addPeer(peer) if err != nil { return } } -func addPeer(peer Peer) error { +func (s *WireGuardService) addPeer(peer Peer) error { pubKey, err := wgtypes.ParseKey(peer.PublicKey) if err != nil { return fmt.Errorf("failed to parse public key: %v", err) @@ -478,7 +476,7 @@ func addPeer(peer Peer) error { return nil } -func handleRemovePeer(msg websocket.WSMessage) { +func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } type RemoveRequest struct { PublicKey string `json:"publicKey"` @@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) { return } - if err := removePeer(request.PublicKey); err != nil { + if err := s.removePeer(request.PublicKey); err != nil { logger.Info("Error removing peer: %v", err) return } } -func removePeer(publicKey string) error { +func (s *WireGuardService) removePeer(publicKey string) error { pubKey, err := wgtypes.ParseKey(publicKey) if err != nil { return fmt.Errorf("failed to parse public key: %v", err) @@ -525,18 +523,18 @@ func removePeer(publicKey string) error { return nil } -func periodicBandwidthCheck(endpoint string) { +func (s *WireGuardService) periodicBandwidthCheck() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for range ticker.C { - if err := reportPeerBandwidth(endpoint); err != nil { + if err := s.reportPeerBandwidth(); err != nil { logger.Info("Failed to report peer bandwidth: %v", err) } } } -func calculatePeerBandwidth() ([]PeerBandwidth, error) { +func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { device, err := wgClient.Device(interfaceName) if err != nil { return nil, fmt.Errorf("failed to get device: %v", err) @@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { return peerBandwidths, nil } -func reportPeerBandwidth(apiURL string) error { - bandwidths, err := calculatePeerBandwidth() +func (s *WireGuardService) reportPeerBandwidth() error { + bandwidths, err := s.calculatePeerBandwidth() if err != nil { return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } @@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error { return fmt.Errorf("failed to marshal bandwidth data: %v", err) } - resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) + err = s.client.SendMessage("wg/bandwidth", jsonData) if err != nil { return fmt.Errorf("failed to send bandwidth data: %v", err) } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("API returned non-OK status: %s", resp.Status) - } return nil }