diff --git a/api/api.go b/api/api.go index a6ac9cd..442162e 100644 --- a/api/api.go +++ b/api/api.go @@ -61,6 +61,11 @@ type StatusResponse struct { NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` } +type MetadataChangeRequest struct { + Fingerprint map[string]any `json:"fingerprint"` + Postures map[string]any `json:"postures"` +} + // API represents the HTTP server and its state type API struct { addr string @@ -68,10 +73,11 @@ type API struct { listener net.Listener server *http.Server - onConnect func(ConnectionRequest) error - onSwitchOrg func(SwitchOrgRequest) error - onDisconnect func() error - onExit func() error + onConnect func(ConnectionRequest) error + onSwitchOrg func(SwitchOrgRequest) error + onMetadataChange func(MetadataChangeRequest) error + onDisconnect func() error + onExit func() error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -117,6 +123,7 @@ func NewAPIStub() *API { func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, onSwitchOrg func(SwitchOrgRequest) error, + onMetadataChange func(MetadataChangeRequest) error, onDisconnect func() error, onExit func() error, ) { @@ -136,6 +143,7 @@ func (s *API) Start() error { mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/switch-org", s.handleSwitchOrg) + mux.HandleFunc("/metadata", s.handleMetadataChange) mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/health", s.handleHealth) @@ -514,6 +522,32 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { }) } +// handleMetadataChange handles the /metadata endpoint +func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req MetadataChangeRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + logger.Info("Received metadata change request via API: %v", req) + + _ = s.onMetadataChange(req) + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "metadata updated", + }) +} + func (s *API) GetStatus() StatusResponse { return StatusResponse{ Connected: s.isConnected, diff --git a/olm/olm.go b/olm/olm.go index 22a936f..97bd4b7 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -59,6 +59,11 @@ type Olm struct { olmConfig OlmConfig tunnelConfig TunnelConfig + // Metadata to send alongside pings + fingerprint map[string]any + postures map[string]any + metaMu sync.Mutex + stopRegister func() updateRegister func(newData any) @@ -240,6 +245,20 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) return o.SwitchOrg(req.OrgID) }, + // onMetadataChange + func(req api.MetadataChangeRequest) error { + logger.Info("Received change metadata request via API") + + if req.Fingerprint != nil { + o.SetFingerprint(req.Fingerprint) + } + + if req.Postures != nil { + o.SetPostures(req.Postures) + } + + return nil + }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") @@ -346,12 +365,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{ - "publicKey": publicKey.String(), - "relay": !config.Holepunch, - "olmVersion": o.olmConfig.Version, - "olmAgent": o.olmConfig.Agent, - "orgId": config.OrgID, - "userToken": userToken, + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, + "orgId": config.OrgID, + "userToken": userToken, + "fingerprint": o.fingerprint, + "postures": o.postures, }, 1*time.Second, 10) // Invoke onRegistered callback if configured @@ -412,6 +433,19 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } }) + fingerprint := config.InitialFingerprint + if fingerprint == nil { + fingerprint = make(map[string]any) + } + + postures := config.InitialPostures + if postures == nil { + postures = make(map[string]any) + } + + o.SetFingerprint(fingerprint) + o.SetPostures(postures) + // Connect to the WebSocket server if err := o.websocket.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) @@ -608,6 +642,20 @@ func (o *Olm) SwitchOrg(orgID string) error { return nil } +func (o *Olm) SetFingerprint(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() + + o.fingerprint = data +} + +func (o *Olm) SetPostures(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() + + o.postures = data +} + // SetPowerMode switches between normal and low power modes // In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes // In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored diff --git a/olm/types.go b/olm/types.go index 804f8e5..491ed19 100644 --- a/olm/types.go +++ b/olm/types.go @@ -73,5 +73,8 @@ type TunnelConfig struct { OverrideDNS bool TunnelDNS bool + InitialFingerprint map[string]any + InitialPostures map[string]any + DisableRelay bool }