diff --git a/main.go b/main.go index 7306176..eaa9ce5 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "time" "github.com/fosrl/gerbil/logger" + "github.com/fosrl/gerbil/relay" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -88,99 +89,26 @@ func parseLogLevel(level string) logger.LogLevel { } } -// Update the startUDPServer function -func startUDPServer(addr string, server string) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - logger.Fatal("Failed to resolve UDP address: %v", err) - } - - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - logger.Fatal("Failed to start UDP server: %v", err) - } - defer conn.Close() - - logger.Info("UDP server listening on %s", addr) - - buffer := make([]byte, 1024) - for { - n, remoteAddr, err := conn.ReadFromUDP(buffer) - if err != nil { - logger.Error("Error reading UDP packet: %v", err) - continue - } - - var msg HolePunchMessage - if err := json.Unmarshal(buffer[:n], &msg); err != nil { - logger.Error("Error unmarshaling message: %v", err) - continue - } - - // Create endpoint info - endpoint := ClientEndpoint{ - OlmID: msg.OlmID, - NewtID: msg.NewtID, - IP: remoteAddr.IP.String(), - Port: remoteAddr.Port, - Timestamp: time.Now().Unix(), - } - - // Send the endpoint info to the Olm server - go notifyServer(endpoint, server) - } -} - -// Add this new function -func notifyServer(endpoint ClientEndpoint, server string) { - jsonData, err := json.Marshal(endpoint) - if err != nil { - logger.Error("Failed to marshal endpoint data: %v", err) - return - } - - resp, err := http.Post(server, - "application/json", - bytes.NewBuffer(jsonData)) - if err != nil { - logger.Error("Failed to notify Olm server: %v", err) - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - logger.Error("Olm server returned non-OK status: %d, body: %s", - resp.StatusCode, - string(body)) - return - } -} - func main() { var ( err error wgconfig WgConfig configFile string remoteConfigURL string - reportBandwidthTo string generateAndSaveKeyTo string reachableAt string logLevel string mtu string - reportHolePunchTo string ) interfaceName = os.Getenv("INTERFACE") configFile = os.Getenv("CONFIG") remoteConfigURL = os.Getenv("REMOTE_CONFIG") listenAddr = os.Getenv("LISTEN") - reportBandwidthTo = os.Getenv("REPORT_BANDWIDTH_TO") generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") reachableAt = os.Getenv("REACHABLE_AT") logLevel = os.Getenv("LOG_LEVEL") mtu = os.Getenv("MTU") - reportHolePunchTo = os.Getenv("REPORT_HOLE_PUNCH_TO") if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -189,17 +117,16 @@ func main() { flag.StringVar(&configFile, "config", "", "Path to local configuration file") } if remoteConfigURL == "" { - flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration") + flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL of the Pangolin server") } if listenAddr == "" { flag.StringVar(&listenAddr, "listen", ":3003", "Address to listen on") } - if reportBandwidthTo == "" { - flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on") - } - if reportHolePunchTo == "" { - flag.StringVar(&reportHolePunchTo, "reportHolePunchTo", "", "Address to listen on") - } + // DEPRECATED AND UNSED: reportBandwidthTo + // allow reportBandwidthTo to be passed but dont do anything with it just thow it away + reportBandwidthTo := "" + flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "DEPRECATED: Use remoteConfig instead") + if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") } @@ -232,6 +159,10 @@ func main() { logger.Fatal("You must provide either a config file or a remote config URL, not both") } + // clean up the reomte config URL for backwards compatibility + remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/gerbil/get-config") + remoteConfigURL = strings.TrimSuffix(remoteConfigURL, "/") + var key wgtypes.Key // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file if generateAndSaveKeyTo != "" { @@ -279,8 +210,8 @@ func main() { } else { // loop until we get the config for wgconfig.PrivateKey == "" { - logger.Info("Fetching remote config from %s", remoteConfigURL) - wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt) + logger.Info("Fetching remote config from %s", remoteConfigURL+"/gerbil/get-config") + wgconfig, err = loadRemoteConfig(remoteConfigURL+"/gerbil/get-config", key, reachableAt) if err != nil { logger.Error("Failed to load configuration: %v", err) time.Sleep(5 * time.Second) @@ -304,12 +235,14 @@ func main() { // Ensure the WireGuard peers exist ensureWireguardPeers(wgconfig.Peers) - if reportBandwidthTo != "" { - go periodicBandwidthCheck(reportBandwidthTo) - } + go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") - // run the udp server - go startUDPServer(":21820", reportHolePunchTo) + server := relay.NewUDPProxyServer(":21820", remoteConfigURL, key) + err = server.Start() + if err != nil { + logger.Fatal("Failed to start server: %v", err) + } + defer server.Stop() http.HandleFunc("/peer", handlePeer) logger.Info("Starting server on %s", listenAddr) diff --git a/relay/relay.go b/relay/relay.go new file mode 100644 index 0000000..8c02905 --- /dev/null +++ b/relay/relay.go @@ -0,0 +1,299 @@ +package relay + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/fosrl/gerbil/logger" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type HolePunchMessage struct { + OlmID string `json:"olmId"` + NewtID string `json:"newtId"` +} + +type ClientEndpoint struct { + OlmID string `json:"olmId"` + NewtID string `json:"newtId"` + IP string `json:"ip"` + Port int `json:"port"` + Timestamp int64 `json:"timestamp"` +} + +type ProxyMapping struct { + DestinationIP string `json:"destinationIP"` + DestinationPort int `json:"destinationPort"` +} + +type DestinationConn struct { + conn *net.UDPConn + lastUsed time.Time +} + +type InitialMappings struct { + Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port" +} + +type UDPProxyServer struct { + addr string + serverURL string + conn *net.UDPConn + proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port" + connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" + publicKey wgtypes.Key +} + +type Logger interface { + Info(format string, args ...interface{}) + Error(format string, args ...interface{}) + Fatal(format string, args ...interface{}) +} + +func NewUDPProxyServer(addr, serverURL string, publicKey wgtypes.Key) *UDPProxyServer { + return &UDPProxyServer{ + addr: addr, + serverURL: serverURL, + publicKey: publicKey, + } +} + +func (s *UDPProxyServer) Start() error { + // First fetch initial mappings + if err := s.fetchInitialMappings(); err != nil { + return fmt.Errorf("failed to fetch initial mappings: %v", err) + } + + udpAddr, err := net.ResolveUDPAddr("udp", s.addr) + if err != nil { + return err + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + + s.conn = conn + logger.Info("UDP server listening on %s", s.addr) + + go s.handlePackets() + go s.cleanupIdleConnections() + return nil +} + +func (s *UDPProxyServer) Stop() { + s.conn.Close() +} + +func (s *UDPProxyServer) fetchInitialMappings() error { + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.publicKey.PublicKey().String()))) + + resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body) + if err != nil { + return fmt.Errorf("failed to fetch mappings: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server returned non-OK status: %d, body: %s", + resp.StatusCode, string(body)) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %v", err) + } + + logger.Info("Received initial mappings: %s", string(data)) + + var initialMappings InitialMappings + if err := json.Unmarshal(data, &initialMappings); err != nil { + return fmt.Errorf("failed to unmarshal initial mappings: %v", err) + } + + // Store all mappings in our sync.Map + for key, mapping := range initialMappings.Mappings { + s.proxyMappings.Store(key, mapping) + } + + logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) + return nil +} + +func (s *UDPProxyServer) handlePackets() { + buffer := make([]byte, 1500) // Standard MTU size + for { + n, remoteAddr, err := s.conn.ReadFromUDP(buffer) + if err != nil { + logger.Error("Error reading UDP packet: %v", err) + continue + } + + // Otherwise, treat it as an incoming WireGuard or Hole Punch request + if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 { + go s.handleWireGuardPacket(buffer[:n], remoteAddr) + continue + } + + // Try to handle as hole punch message + var msg HolePunchMessage + if err := json.Unmarshal(buffer[:n], &msg); err != nil { + logger.Error("Error unmarshaling message: %v", err) + continue + } + + endpoint := ClientEndpoint{ + OlmID: msg.OlmID, + NewtID: msg.NewtID, + IP: remoteAddr.IP.String(), + Port: remoteAddr.Port, + Timestamp: time.Now().Unix(), + } + + go s.notifyServer(endpoint) + } +} + +func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) { + key := remoteAddr.String() + + // Check if we have an existing connection + if conn, ok := s.connections.Load(key); ok { + destConn := conn.(*DestinationConn) + destConn.lastUsed = time.Now() + return destConn.conn, nil + } + + // Create new connection + newConn, err := net.DialUDP("udp", nil, destAddr) + if err != nil { + return nil, fmt.Errorf("failed to create UDP connection: %v", err) + } + + // Store the new connection + s.connections.Store(key, &DestinationConn{ + conn: newConn, + lastUsed: time.Now(), + }) + + // Start a goroutine to handle responses + go s.handleResponses(newConn, destAddr, remoteAddr) + + return newConn, nil +} + +func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) { + buffer := make([]byte, 1500) + for { + n, err := conn.Read(buffer) + if err != nil { + logger.Error("Error reading response from %s: %v", destAddr.String(), err) + return + } + + // Forward the response back through the main listener + _, err = s.conn.WriteToUDP(buffer[:n], remoteAddr) + if err != nil { + logger.Error("Failed to forward response: %v", err) + } + } +} + +func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) { + key := remoteAddr.String() + mapping, ok := s.proxyMappings.Load(key) + if !ok { + logger.Error("No proxy mapping found for %s", key) + return + } + + proxyMapping := mapping.(ProxyMapping) + destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", + proxyMapping.DestinationIP, proxyMapping.DestinationPort)) + if err != nil { + logger.Error("Failed to resolve destination address: %v", err) + return + } + + // Get or create a connection to the destination + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + return + } + + // Forward the packet + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to proxy packet: %v", err) + } +} + +// Add a cleanup method to periodically remove idle connections +func (s *UDPProxyServer) cleanupIdleConnections() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + now := time.Now() + s.connections.Range(func(key, value interface{}) bool { + destConn := value.(*DestinationConn) + if now.Sub(destConn.lastUsed) > 10*time.Minute { + destConn.conn.Close() + s.connections.Delete(key) + } + return true + }) + } +} + +func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { + jsonData, err := json.Marshal(endpoint) + if err != nil { + logger.Error("Failed to marshal endpoint data: %v", err) + return + } + + resp, err := http.Post(s.serverURL+"/gerbil/update-hole-punch", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + logger.Error("Failed to notify server: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + logger.Error("Server returned non-OK status: %d, body: %s", + resp.StatusCode, string(body)) + return + } + + // Parse the proxy mapping response + var mapping ProxyMapping + if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil { + logger.Error("Failed to decode proxy mapping: %v", err) + return + } + + // Store the mapping + key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port) + s.proxyMappings.Store(key, mapping) + + logger.Debug("Stored proxy mapping for %s: %v", key, mapping) +} + +func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, + destinationIP string, destinationPort int) { + key := net.JoinHostPort(sourceIP, string(sourcePort)) + mapping := ProxyMapping{ + DestinationIP: destinationIP, + DestinationPort: destinationPort, + } + s.proxyMappings.Store(key, mapping) +}