diff --git a/api/api.go b/api/api.go index a6ac9cd..442162e 100644 --- a/api/api.go +++ b/api/api.go @@ -61,6 +61,11 @@ type StatusResponse struct { NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` } +type MetadataChangeRequest struct { + Fingerprint map[string]any `json:"fingerprint"` + Postures map[string]any `json:"postures"` +} + // API represents the HTTP server and its state type API struct { addr string @@ -68,10 +73,11 @@ type API struct { listener net.Listener server *http.Server - onConnect func(ConnectionRequest) error - onSwitchOrg func(SwitchOrgRequest) error - onDisconnect func() error - onExit func() error + onConnect func(ConnectionRequest) error + onSwitchOrg func(SwitchOrgRequest) error + onMetadataChange func(MetadataChangeRequest) error + onDisconnect func() error + onExit func() error statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -117,6 +123,7 @@ func NewAPIStub() *API { func (s *API) SetHandlers( onConnect func(ConnectionRequest) error, onSwitchOrg func(SwitchOrgRequest) error, + onMetadataChange func(MetadataChangeRequest) error, onDisconnect func() error, onExit func() error, ) { @@ -136,6 +143,7 @@ func (s *API) Start() error { mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/switch-org", s.handleSwitchOrg) + mux.HandleFunc("/metadata", s.handleMetadataChange) mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/health", s.handleHealth) @@ -514,6 +522,32 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { }) } +// handleMetadataChange handles the /metadata endpoint +func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req MetadataChangeRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + logger.Info("Received metadata change request via API: %v", req) + + _ = s.onMetadataChange(req) + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "status": "metadata updated", + }) +} + func (s *API) GetStatus() StatusResponse { return StatusResponse{ Connected: s.isConnected, diff --git a/olm/olm.go b/olm/olm.go index 2db3630..0810025 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,6 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "sync" "time" "github.com/fosrl/newt/bind" @@ -51,6 +52,11 @@ type Olm struct { olmConfig OlmConfig tunnelConfig TunnelConfig + // Metadata to send alongside pings + fingerprint map[string]any + postures map[string]any + metaMu sync.Mutex + stopRegister func() stopPeerSend func() updateRegister func(newData any) @@ -229,6 +235,20 @@ func (o *Olm) registerAPICallbacks() { logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) return o.SwitchOrg(req.OrgID) }, + // onMetadataChange + func(req api.MetadataChangeRequest) error { + logger.Info("Received change metadata request via API") + + if req.Fingerprint != nil { + o.SetFingerprint(req.Fingerprint) + } + + if req.Postures != nil { + o.SetPostures(req.Postures) + } + + return nil + }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") @@ -336,12 +356,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ - "publicKey": publicKey.String(), - "relay": !config.Holepunch, - "olmVersion": o.olmConfig.Version, - "olmAgent": o.olmConfig.Agent, - "orgId": config.OrgID, - "userToken": userToken, + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": o.olmConfig.Version, + "olmAgent": o.olmConfig.Agent, + "orgId": config.OrgID, + "userToken": userToken, + "fingerprint": o.fingerprint, + "postures": o.postures, }, 1*time.Second) // Invoke onRegistered callback if configured @@ -404,6 +426,19 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } }) + fingerprint := config.InitialFingerprint + if fingerprint == nil { + fingerprint = make(map[string]any) + } + + postures := config.InitialPostures + if postures == nil { + postures = make(map[string]any) + } + + o.SetFingerprint(fingerprint) + o.SetPostures(postures) + // Connect to the WebSocket server if err := olmClient.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) @@ -577,28 +612,16 @@ func (o *Olm) SwitchOrg(orgID string) error { return nil } -func (o *Olm) AddDevice(fd uint32) error { - if o.middleDev == nil { - return fmt.Errorf("middle device is not initialized") - } +func (o *Olm) SetFingerprint(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() - if o.tunnelConfig.MTU == 0 { - return fmt.Errorf("tunnel MTU is not set") - } - - tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) - if err != nil { - return fmt.Errorf("failed to create TUN device from fd: %v", err) - } - - // Update interface name if available - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - o.tunnelConfig.InterfaceName = realInterfaceName - } - - // Replace the existing TUN device in the middle device with the new one - o.middleDev.AddDevice(tdev) - - logger.Info("Added device from file descriptor %d", fd) - return nil + o.fingerprint = data +} + +func (o *Olm) SetPostures(data map[string]any) { + o.metaMu.Lock() + defer o.metaMu.Unlock() + + o.postures = data } diff --git a/olm/ping.go b/olm/ping.go index fd7706a..460fc38 100644 --- a/olm/ping.go +++ b/olm/ping.go @@ -8,11 +8,12 @@ import ( "github.com/fosrl/olm/websocket" ) -func sendPing(olm *websocket.Client) error { - logger.Debug("Sending ping message") +func (o *Olm) sendPing(olm *websocket.Client) error { err := olm.SendMessage("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, + "timestamp": time.Now().Unix(), + "userToken": olm.GetConfig().UserToken, + "fingerprint": o.fingerprint, + "postures": o.postures, }) if err != nil { logger.Error("Failed to send ping message: %v", err) @@ -24,7 +25,7 @@ func sendPing(olm *websocket.Client) error { func (o *Olm) keepSendingPing(olm *websocket.Client) { // Send ping immediately on startup - if err := sendPing(olm); err != nil { + if err := o.sendPing(olm); err != nil { logger.Error("Failed to send initial ping: %v", err) } else { logger.Info("Sent initial ping message") @@ -40,7 +41,7 @@ func (o *Olm) keepSendingPing(olm *websocket.Client) { logger.Info("Stopping ping messages") return case <-ticker.C: - if err := sendPing(olm); err != nil { + if err := o.sendPing(olm); err != nil { logger.Error("Failed to send periodic ping: %v", err) } } diff --git a/olm/types.go b/olm/types.go index 77c0b5f..28e2260 100644 --- a/olm/types.go +++ b/olm/types.go @@ -67,5 +67,8 @@ type TunnelConfig struct { OverrideDNS bool TunnelDNS bool + InitialFingerprint map[string]any + InitialPostures map[string]any + DisableRelay bool }