diff --git a/httpserver/httpserver.go b/httpserver/httpserver.go new file mode 100644 index 0000000..a3c3d3b --- /dev/null +++ b/httpserver/httpserver.go @@ -0,0 +1,177 @@ +package httpserver + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/fosrl/newt/logger" +) + +// ConnectionRequest defines the structure for an incoming connection request +type ConnectionRequest struct { + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` +} + +// PeerStatus represents the status of a peer connection +type PeerStatus struct { + SiteID int `json:"siteId"` + Connected bool `json:"connected"` + RTT time.Duration `json:"rtt"` + LastSeen time.Time `json:"lastSeen"` +} + +// StatusResponse is returned by the status endpoint +type StatusResponse struct { + Status string `json:"status"` + Connected bool `json:"connected"` + TunnelIP string `json:"tunnelIP,omitempty"` + PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` +} + +// HTTPServer represents the HTTP server and its state +type HTTPServer struct { + addr string + server *http.Server + connectionChan chan ConnectionRequest + statusMu sync.RWMutex + peerStatuses map[int]*PeerStatus + connectedAt time.Time + isConnected bool +} + +// NewHTTPServer creates a new HTTP server +func NewHTTPServer(addr string) *HTTPServer { + s := &HTTPServer{ + addr: addr, + connectionChan: make(chan ConnectionRequest, 1), + peerStatuses: make(map[int]*PeerStatus), + } + + return s +} + +// Start starts the HTTP server +func (s *HTTPServer) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/connect", s.handleConnect) + mux.HandleFunc("/status", s.handleStatus) + + s.server = &http.Server{ + Addr: s.addr, + Handler: mux, + } + + logger.Info("Starting HTTP server on %s", s.addr) + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Error("HTTP server error: %v", err) + } + }() + + return nil +} + +// Stop stops the HTTP server +func (s *HTTPServer) Stop() error { + logger.Info("Stopping HTTP server") + return s.server.Close() +} + +// GetConnectionChannel returns the channel for receiving connection requests +func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest { + return s.connectionChan +} + +// UpdatePeerStatus updates the status of a peer +func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + status, exists := s.peerStatuses[siteID] + if !exists { + status = &PeerStatus{ + SiteID: siteID, + } + s.peerStatuses[siteID] = status + } + + status.Connected = connected + status.RTT = rtt + status.LastSeen = time.Now() +} + +// SetConnectionStatus sets the overall connection status +func (s *HTTPServer) SetConnectionStatus(isConnected bool) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + s.isConnected = isConnected + + if isConnected { + s.connectedAt = time.Now() + } else { + // Clear peer statuses when disconnected + s.peerStatuses = make(map[int]*PeerStatus) + } +} + +// handleConnect handles the /connect endpoint +func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req ConnectionRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if req.ID == "" || req.Secret == "" || req.Endpoint == "" { + http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest) + return + } + + // Send the request to the main goroutine + s.connectionChan <- req + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]string{ + "status": "connection request accepted", + }) +} + +// handleStatus handles the /status endpoint +func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + s.statusMu.RLock() + defer s.statusMu.RUnlock() + + resp := StatusResponse{ + Connected: s.isConnected, + PeerStatuses: s.peerStatuses, + } + + if s.isConnected { + resp.Status = "connected" + } else { + resp.Status = "disconnected" + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} diff --git a/main.go b/main.go index 967c390..0cbc16b 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ import ( "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/httpserver" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "github.com/vishvananda/netlink" @@ -130,6 +131,8 @@ func main() { err error logLevel string interfaceName string + enableHTTP bool + httpAddr string ) stopHolepunch = make(chan struct{}) @@ -144,6 +147,7 @@ func main() { dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") interfaceName = os.Getenv("INTERFACE") + httpAddr = os.Getenv("HTTP_ADDR") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") @@ -166,6 +170,11 @@ func main() { if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") } + if httpAddr == "" { + flag.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')") + } + + flag.BoolVar(&enableHTTP, "http", false, "Enable HTTP server") // do a --version check version := flag.Bool("version", false, "Print the version") @@ -181,6 +190,32 @@ func main() { loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + var httpServer *httpserver.HTTPServer + if enableHTTP { + httpServer = httpserver.NewHTTPServer(httpAddr) + if err := httpServer.Start(); err != nil { + logger.Fatal("Failed to start HTTP server: %v", err) + } + + // Use a goroutine to handle connection requests + go func() { + for req := range httpServer.GetConnectionChannel() { + logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) + + // Set the connection parameters + id = req.ID + secret = req.Secret + endpoint = req.Endpoint + } + }() + } + + // wait until we have a client id and secret and endpoint + for id == "" || secret == "" || endpoint == "" { + logger.Debug("Waiting for client ID, secret, and endpoint...") + time.Sleep(1 * time.Second) + } + // parse the mtu string into an int mtuInt, err = strconv.Atoi(mtu) if err != nil { @@ -357,6 +392,9 @@ func main() { peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { + if httpServer != nil { + httpServer.UpdatePeerStatus(siteID, connected, rtt) + } if connected { logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) } else { @@ -370,6 +408,9 @@ func main() { // loop over the sites and call ConfigurePeer for each one for _, site := range wgData.Sites { + if httpServer != nil { + httpServer.UpdatePeerStatus(site.SiteId, false, 0) + } err = ConfigurePeer(dev, site, privateKey, endpoint) if err != nil { logger.Error("Failed to configure peer: %v", err) @@ -548,6 +589,10 @@ func main() { go keepSendingRegistration(olm, publicKey.String()) go keepSendingPing(olm) + if httpServer != nil { + httpServer.SetConnectionStatus(true) + } + logger.Info("Sent registration message") return nil })