diff --git a/api/api.go b/api/api.go index 047ce08..895140b 100644 --- a/api/api.go +++ b/api/api.go @@ -78,6 +78,13 @@ type MetadataChangeRequest struct { Postures map[string]any `json:"postures"` } +// JITConnectionRequest defines the structure for a dynamic Just-In-Time connection request. +// Either SiteID or ResourceID must be provided (but not necessarily both). +type JITConnectionRequest struct { + Site string `json:"site,omitempty"` + Resource string `json:"resource,omitempty"` +} + // API represents the HTTP server and its state type API struct { addr string @@ -92,6 +99,7 @@ type API struct { onExit func() error onRebind func() error onPowerMode func(PowerModeRequest) error + onJITConnect func(JITConnectionRequest) error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -143,6 +151,7 @@ func (s *API) SetHandlers( onExit func() error, onRebind func() error, onPowerMode func(PowerModeRequest) error, + onJITConnect func(JITConnectionRequest) error, ) { s.onConnect = onConnect s.onSwitchOrg = onSwitchOrg @@ -151,6 +160,7 @@ func (s *API) SetHandlers( s.onExit = onExit s.onRebind = onRebind s.onPowerMode = onPowerMode + s.onJITConnect = onJITConnect } // Start starts the HTTP server @@ -169,6 +179,7 @@ func (s *API) Start() error { mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/rebind", s.handleRebind) mux.HandleFunc("/power-mode", s.handlePowerMode) + mux.HandleFunc("/jit-connect", s.handleJITConnect) s.server = &http.Server{ Handler: mux, @@ -633,6 +644,54 @@ func (s *API) handleRebind(w http.ResponseWriter, r *http.Request) { }) } +// handleJITConnect handles the /jit-connect endpoint. +// It initiates a dynamic Just-In-Time connection to a site identified by either +// a site or a resource. Exactly one of the two must be provided. +func (s *API) handleJITConnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req JITConnectionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Validate that exactly one of site or resource is provided + if req.Site == "" && req.Resource == "" { + http.Error(w, "Missing required field: either site or resource must be provided", http.StatusBadRequest) + return + } + if req.Site != "" && req.Resource != "" { + http.Error(w, "Ambiguous request: provide either site or resource, not both", http.StatusBadRequest) + return + } + + if req.Site != "" { + logger.Info("Received JIT connection request via API: site=%s", req.Site) + } else { + logger.Info("Received JIT connection request via API: resource=%s", req.Resource) + } + + if s.onJITConnect != nil { + if err := s.onJITConnect(req); err != nil { + http.Error(w, fmt.Sprintf("JIT connection failed: %v", err), http.StatusInternalServerError) + return + } + } else { + http.Error(w, "JIT connect handler not configured", http.StatusNotImplemented) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "JIT connection request accepted", + }) +} + // handlePowerMode handles the /power-mode endpoint // This allows changing the power mode between "normal" and "low" func (s *API) handlePowerMode(w http.ResponseWriter, r *http.Request) { diff --git a/olm/data.go b/olm/data.go index 8bd0997..015931b 100644 --- a/olm/data.go +++ b/olm/data.go @@ -220,6 +220,7 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Info("Sync: Adding new peer for site %d", siteId) o.holePunchManager.TriggerHolePunch() + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud // // TODO: do we need to send the message to the cloud to add the peer that way? // if err := o.peerManager.AddPeer(expectedSite); err != nil { diff --git a/olm/olm.go b/olm/olm.go index 9bd41b2..fa32ebd 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -66,6 +66,7 @@ type Olm struct { updateRegister func(newData any) stopPeerSend func() + stopPeerInit func() // WaitGroup to track tunnel lifecycle tunnelWg sync.WaitGroup @@ -284,6 +285,16 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Processing power mode change request via API: mode=%s", req.Mode) return o.SetPowerMode(req.Mode) }, + func(req api.JITConnectionRequest) error { + logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource) + + o.stopPeerInit, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": req.Site, + "resourceId": req.Resource, + }, 2*time.Second, 10) + + return nil + }, ) }