Merge branch 'bubble-errors-up' into dev

This commit is contained in:
Owen
2026-01-18 11:38:20 -08:00
4 changed files with 98 additions and 14 deletions

View File

@@ -49,11 +49,18 @@ type PeerStatus struct {
HolepunchConnected bool `json:"holepunchConnected"` HolepunchConnected bool `json:"holepunchConnected"`
} }
// OlmError holds error information from registration failures
type OlmError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// StatusResponse is returned by the status endpoint // StatusResponse is returned by the status endpoint
type StatusResponse struct { type StatusResponse struct {
Connected bool `json:"connected"` Connected bool `json:"connected"`
Registered bool `json:"registered"` Registered bool `json:"registered"`
Terminated bool `json:"terminated"` Terminated bool `json:"terminated"`
OlmError *OlmError `json:"error,omitempty"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
Agent string `json:"agent,omitempty"` Agent string `json:"agent,omitempty"`
OrgID string `json:"orgId,omitempty"` OrgID string `json:"orgId,omitempty"`
@@ -86,6 +93,7 @@ type API struct {
isConnected bool isConnected bool
isRegistered bool isRegistered bool
isTerminated bool isTerminated bool
olmError *OlmError
version string version string
agent string agent string
@@ -141,7 +149,7 @@ func (s *API) Start() error {
if s.socketPath == "" && s.addr == "" { if s.socketPath == "" && s.addr == "" {
return fmt.Errorf("either socketPath or addr must be provided to start the API server") return fmt.Errorf("either socketPath or addr must be provided to start the API server")
} }
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/status", s.handleStatus)
@@ -264,6 +272,27 @@ func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock() s.statusMu.Lock()
defer s.statusMu.Unlock() defer s.statusMu.Unlock()
s.isRegistered = registered s.isRegistered = registered
// Clear any registration error when successfully registered
if registered {
s.olmError = nil
}
}
// SetOlmError sets the registration error
func (s *API) SetOlmError(code string, message string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = &OlmError{
Code: code,
Message: message,
}
}
// ClearOlmError clears any registration error
func (s *API) ClearOlmError() {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.olmError = nil
} }
func (s *API) SetTerminated(terminated bool) { func (s *API) SetTerminated(terminated bool) {
@@ -391,6 +420,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
Connected: s.isConnected, Connected: s.isConnected,
Registered: s.isRegistered, Registered: s.isRegistered,
Terminated: s.isTerminated, Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version, Version: s.version,
Agent: s.agent, Agent: s.agent,
OrgID: s.orgID, OrgID: s.orgID,
@@ -557,6 +587,7 @@ func (s *API) GetStatus() StatusResponse {
Connected: s.isConnected, Connected: s.isConnected,
Registered: s.isRegistered, Registered: s.isRegistered,
Terminated: s.isTerminated, Terminated: s.isTerminated,
OlmError: s.olmError,
Version: s.version, Version: s.version,
Agent: s.agent, Agent: s.agent,
OrgID: s.orgID, OrgID: s.orgID,

View File

@@ -19,6 +19,12 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
// OlmErrorData represents the error data sent from the server
type OlmErrorData struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (o *Olm) handleConnect(msg websocket.WSMessage) { func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data) logger.Debug("Received message: %v", msg.Data)
@@ -209,8 +215,51 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Info("WireGuard device created.") logger.Info("WireGuard device created.")
} }
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
logger.Debug("Received olm error message: %v", msg.Data)
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling olm error data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling olm error data: %v", err)
return
}
logger.Error("Olm error (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
// Invoke onOlmError callback if configured
if o.olmConfig.OnOlmError != nil {
go o.olmConfig.OnOlmError(errorData.Code, errorData.Message)
}
}
func (o *Olm) handleTerminate(msg websocket.WSMessage) { func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Info("Received terminate message") logger.Info("Received terminate message")
var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling terminate error data: %v", err)
} else {
if err := json.Unmarshal(jsonData, &errorData); err != nil {
logger.Error("Error unmarshaling terminate error data: %v", err)
} else {
logger.Info("Terminate reason (code: %s): %s", errorData.Code, errorData.Message)
// Set the olm error in the API server so it can be exposed via status
o.apiServer.SetOlmError(errorData.Code, errorData.Message)
}
}
o.apiServer.SetTerminated(true) o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false) o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false) o.apiServer.SetRegistered(false)

View File

@@ -342,6 +342,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handlers for managing connection status // Handlers for managing connection status
o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect) o.websocket.RegisterHandler("olm/wg/connect", o.handleConnect)
o.websocket.RegisterHandler("olm/error", o.handleOlmError)
o.websocket.RegisterHandler("olm/terminate", o.handleTerminate) o.websocket.RegisterHandler("olm/terminate", o.handleTerminate)
// Handlers for managing peers // Handlers for managing peers
@@ -434,6 +435,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
o.apiServer.SetTerminated(true) o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false) o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false) o.apiServer.SetRegistered(false)
o.apiServer.ClearOlmError()
o.apiServer.ClearPeerStatuses() o.apiServer.ClearPeerStatuses()
network.ClearNetworkSettings() network.ClearNetworkSettings()
@@ -478,20 +480,20 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
for { for {
select { select {
case <-sigChan: case <-sigChan:
logger.Info("SIGHUP received, toggling power mode") logger.Info("SIGHUP received, toggling power mode")
if powerMode == "normal" { if powerMode == "normal" {
powerMode = "low" powerMode = "low"
if err := o.SetPowerMode("low"); err != nil { if err := o.SetPowerMode("low"); err != nil {
logger.Error("Failed to set low power mode: %v", err) logger.Error("Failed to set low power mode: %v", err)
}
} else {
powerMode = "normal"
if err := o.SetPowerMode("normal"); err != nil {
logger.Error("Failed to set normal power mode: %v", err)
}
} }
} else {
powerMode = "normal"
if err := o.SetPowerMode("normal"); err != nil {
logger.Error("Failed to set normal power mode: %v", err)
}
}
case <-tunnelCtx.Done(): case <-tunnelCtx.Done():
return return
} }
@@ -604,6 +606,7 @@ func (o *Olm) StopTunnel() error {
// Update API server status // Update API server status
o.apiServer.SetConnectionStatus(false) o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false) o.apiServer.SetRegistered(false)
o.apiServer.ClearOlmError()
network.ClearNetworkSettings() network.ClearNetworkSettings()
o.apiServer.ClearPeerStatuses() o.apiServer.ClearPeerStatuses()

View File

@@ -46,6 +46,7 @@ type OlmConfig struct {
OnConnected func() OnConnected func()
OnTerminated func() OnTerminated func()
OnAuthError func(statusCode int, message string) // Called when auth fails (401/403) OnAuthError func(statusCode int, message string) // Called when auth fails (401/403)
OnOlmError func(code string, message string) // Called when registration fails
OnExit func() // Called when exit is requested via API OnExit func() // Called when exit is requested via API
} }