diff --git a/api/api.go b/api/api.go index e18bee7..b11cc70 100644 --- a/api/api.go +++ b/api/api.go @@ -49,11 +49,18 @@ type PeerStatus struct { HolepunchConnected bool `json:"holepunchConnected"` } +// OlmError holds error information from registration failures +type OlmError struct { + Code string `json:"code"` + Message string `json:"message"` +} + // StatusResponse is returned by the status endpoint type StatusResponse struct { Connected bool `json:"connected"` Registered bool `json:"registered"` Terminated bool `json:"terminated"` + OlmError *OlmError `json:"error,omitempty"` Version string `json:"version,omitempty"` Agent string `json:"agent,omitempty"` OrgID string `json:"orgId,omitempty"` @@ -86,6 +93,7 @@ type API struct { isConnected bool isRegistered bool isTerminated bool + olmError *OlmError version string agent string @@ -141,7 +149,7 @@ func (s *API) Start() error { if s.socketPath == "" && s.addr == "" { return fmt.Errorf("either socketPath or addr must be provided to start the API server") } - + mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) @@ -264,6 +272,27 @@ func (s *API) SetRegistered(registered bool) { s.statusMu.Lock() defer s.statusMu.Unlock() s.isRegistered = registered + // Clear any registration error when successfully registered + if registered { + s.olmError = nil + } +} + +// SetOlmError sets the registration error +func (s *API) SetOlmError(code string, message string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = &OlmError{ + Code: code, + Message: message, + } +} + +// ClearOlmError clears any registration error +func (s *API) ClearOlmError() { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.olmError = nil } func (s *API) SetTerminated(terminated bool) { @@ -391,6 +420,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, @@ -557,6 +587,7 @@ func (s *API) GetStatus() StatusResponse { Connected: s.isConnected, Registered: s.isRegistered, Terminated: s.isTerminated, + OlmError: s.olmError, Version: s.version, Agent: s.agent, OrgID: s.orgID, diff --git a/olm/connect.go b/olm/connect.go index 7f3785e..575a8fd 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -19,6 +19,12 @@ import ( "golang.zx2c4.com/wireguard/tun" ) +// OlmErrorData represents the error data sent from the server +type OlmErrorData struct { + Code string `json:"code"` + Message string `json:"message"` +} + func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -209,8 +215,51 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Info("WireGuard device created.") } +func (o *Olm) handleOlmError(msg websocket.WSMessage) { + logger.Debug("Received olm error message: %v", msg.Data) + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling olm error data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling olm error data: %v", err) + return + } + + logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message) + + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + + // Invoke onOlmError callback if configured + if o.olmConfig.OnOlmError != nil { + go o.olmConfig.OnOlmError(errorData.Code, errorData.Message) + } +} + func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Info("Received terminate message") + + var errorData OlmErrorData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling terminate error data: %v", err) + } else { + if err := json.Unmarshal(jsonData, &errorData); err != nil { + logger.Error("Error unmarshaling terminate error data: %v", err) + } else { + logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message) + // Set the olm error in the API server so it can be exposed via status + o.apiServer.SetOlmError(errorData.Code, errorData.Message) + } + } + o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) diff --git a/olm/olm.go b/olm/olm.go index b43ddd7..6c975d3 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -342,6 +342,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { // Handlers for managing connection status o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect) + o.websocket.RegisterHandler("olm/error", o.handleOlmError) o.websocket.RegisterHandler("olm/terminate", o.handleTerminate) // Handlers for managing peers @@ -434,6 +435,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() o.apiServer.ClearPeerStatuses() network.ClearNetworkSettings() @@ -478,20 +480,20 @@ func (o *Olm) StartTunnel(config TunnelConfig) { for { select { case <-sigChan: - - logger.Info("SIGHUP received, toggling power mode") - if powerMode == "normal" { - powerMode = "low" - if err := o.SetPowerMode("low"); err != nil { - logger.Error("Failed to set low power mode: %v", err) + + logger.Info("SIGHUP received, toggling power mode") + if powerMode == "normal" { + powerMode = "low" + if err := o.SetPowerMode("low"); err != nil { + logger.Error("Failed to set low power mode: %v", err) + } + } else { + powerMode = "normal" + if err := o.SetPowerMode("normal"); err != nil { + logger.Error("Failed to set normal power mode: %v", err) + } } - } else { - powerMode = "normal" - if err := o.SetPowerMode("normal"); err != nil { - logger.Error("Failed to set normal power mode: %v", err) - } - } - + case <-tunnelCtx.Done(): return } @@ -604,6 +606,7 @@ func (o *Olm) StopTunnel() error { // Update API server status o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) + o.apiServer.ClearOlmError() network.ClearNetworkSettings() o.apiServer.ClearPeerStatuses() diff --git a/olm/types.go b/olm/types.go index 2e56ad7..198b222 100644 --- a/olm/types.go +++ b/olm/types.go @@ -46,6 +46,7 @@ type OlmConfig struct { OnConnected func() OnTerminated func() OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) + OnOlmError func(code string, message string) // Called when registration fails OnExit func() // Called when exit is requested via API }