diff --git a/api/api.go b/api/api.go index 442162e..b85b041 100644 --- a/api/api.go +++ b/api/api.go @@ -49,11 +49,18 @@ type PeerStatus struct { HolepunchConnected bool `json:"holepunchConnected"` } +// OlmError holds error information from registration failures +type OlmError struct { + Code string `json:"code"` + Message string `json:"message"` +} + // StatusResponse is returned by the status endpoint type StatusResponse struct { Connected bool `json:"connected"` Registered bool `json:"registered"` Terminated bool `json:"terminated"` + OlmError *OlmError `json:"error,omitempty"` Version string `json:"version,omitempty"` Agent string `json:"agent,omitempty"` OrgID string `json:"orgId,omitempty"` @@ -85,6 +92,7 @@ type API struct { isConnected bool isRegistered bool isTerminated bool + olmError *OlmError version string agent string @@ -138,7 +146,7 @@ func (s *API) Start() error { if s.socketPath == "" && s.addr == "" { return fmt.Errorf("either socketPath or addr must be provided to start the API server") } - + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) @@ -260,6 +268,27 @@ func (s *API) SetRegistered(registered bool) { s.statusMu.Lock() defer s.statusMu.Unlock() s.isRegistered = registered + // Clear any registration error when successfully registered + if registered { + s.olmError = nil + } +} + +// SetOlmError sets the registration error +func (s *API) SetOlmError(code string, message string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = &OlmError{ + Code: code, + Message: message, + } +} + +// ClearOlmError clears any registration error +func (s *API) ClearOlmError() { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = nil } func (s *API) SetTerminated(terminated bool) { @@ -387,6 +416,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, @@ -553,6 +583,7 @@ func (s *API) GetStatus() StatusResponse { Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, diff --git a/olm/connect.go b/olm/connect.go index a610ea4..ebe7009 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -19,6 +19,12 @@ import ( "golang.zx2c4.com/wireguard/tun" ) +// OlmErrorData represents the error data sent from the server +type OlmErrorData struct { + Code string `json:"code"` + Message string `json:"message"` +} + func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -206,11 +212,39 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Info("WireGuard device created.") } +func (o *Olm) handleOlmError(msg websocket.WSMessage) { + logger.Debug("Received olm error message: %v", msg.Data) + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling olm error data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling olm error data: %v", err) + return + } + + logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message) + + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + + // Invoke onOlmError callback if configured + if o.olmConfig.OnOlmError != nil { + go o.olmConfig.OnOlmError(errorData.Code, errorData.Message) + } +} + func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Info("Received terminate message") o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() diff --git a/olm/olm.go b/olm/olm.go index bc06602..df6cad0 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -337,6 +337,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { // Handlers for managing connection status o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect) + o.websocket.RegisterHandler("olm/error", o.handleOlmError) o.websocket.RegisterHandler("olm/terminate", o.handleTerminate) // Handlers for managing peers @@ -427,6 +428,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() @@ -471,20 +473,20 @@ func (o *Olm) StartTunnel(config TunnelConfig) { for { select { case <-sigChan: - - logger.Info("SIGHUP received, toggling power mode") - if powerMode == "normal" { - powerMode = "low" - if err := o.SetPowerMode("low"); err != nil { - logger.Error("Failed to set low power mode: %v", err) + + logger.Info("SIGHUP received, toggling power mode") + if powerMode == "normal" { + powerMode = "low" + if err := o.SetPowerMode("low"); err != nil { + logger.Error("Failed to set low power mode: %v", err) + } + } else { + powerMode = "normal" + if err := o.SetPowerMode("normal"); err != nil { + logger.Error("Failed to set normal power mode: %v", err) + } } - } else { - powerMode = "normal" - if err := o.SetPowerMode("normal"); err != nil { - logger.Error("Failed to set normal power mode: %v", err) - } - } - + case <-tunnelCtx.Done(): return } @@ -597,6 +599,7 @@ func (o *Olm) StopTunnel() error { // Update API server status o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() network.ClearNetworkSettings() o.apiServer.ClearPeerStatuses() diff --git a/olm/types.go b/olm/types.go index 2e56ad7..198b222 100644 --- a/olm/types.go +++ b/olm/types.go @@ -46,6 +46,7 @@ type OlmConfig struct { OnConnected func() OnTerminated func() OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) + OnOlmError func(code string, message string) // Called when registration fails OnExit func() // Called when exit is requested via API }