Merge branch 'dev' of github.com:fosrl/olm into dev

Former-commit-id: 1c21071ee1
This commit is contained in:
Owen
2026-01-15 16:37:09 -08:00
4 changed files with 100 additions and 39 deletions

View File

@@ -61,6 +61,11 @@ type StatusResponse struct {
NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"` NetworkSettings network.NetworkSettings `json:"networkSettings,omitempty"`
} }
type MetadataChangeRequest struct {
Fingerprint map[string]any `json:"fingerprint"`
Postures map[string]any `json:"postures"`
}
// API represents the HTTP server and its state // API represents the HTTP server and its state
type API struct { type API struct {
addr string addr string
@@ -68,10 +73,11 @@ type API struct {
listener net.Listener listener net.Listener
server *http.Server server *http.Server
onConnect func(ConnectionRequest) error onConnect func(ConnectionRequest) error
onSwitchOrg func(SwitchOrgRequest) error onSwitchOrg func(SwitchOrgRequest) error
onDisconnect func() error onMetadataChange func(MetadataChangeRequest) error
onExit func() error onDisconnect func() error
onExit func() error
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
@@ -117,6 +123,7 @@ func NewAPIStub() *API {
func (s *API) SetHandlers( func (s *API) SetHandlers(
onConnect func(ConnectionRequest) error, onConnect func(ConnectionRequest) error,
onSwitchOrg func(SwitchOrgRequest) error, onSwitchOrg func(SwitchOrgRequest) error,
onMetadataChange func(MetadataChangeRequest) error,
onDisconnect func() error, onDisconnect func() error,
onExit func() error, onExit func() error,
) { ) {
@@ -136,6 +143,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg) mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/metadata", s.handleMetadataChange)
mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/exit", s.handleExit)
mux.HandleFunc("/health", s.handleHealth) mux.HandleFunc("/health", s.handleHealth)
@@ -514,6 +522,32 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
}) })
} }
// handleMetadataChange handles the /metadata endpoint
func (s *API) handleMetadataChange(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req MetadataChangeRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
logger.Info("Received metadata change request via API: %v", req)
_ = s.onMetadataChange(req)
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "metadata updated",
})
}
func (s *API) GetStatus() StatusResponse { func (s *API) GetStatus() StatusResponse {
return StatusResponse{ return StatusResponse{
Connected: s.isConnected, Connected: s.isConnected,

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"sync"
"time" "time"
"github.com/fosrl/newt/bind" "github.com/fosrl/newt/bind"
@@ -51,6 +52,11 @@ type Olm struct {
olmConfig OlmConfig olmConfig OlmConfig
tunnelConfig TunnelConfig tunnelConfig TunnelConfig
// Metadata to send alongside pings
fingerprint map[string]any
postures map[string]any
metaMu sync.Mutex
stopRegister func() stopRegister func()
stopPeerSend func() stopPeerSend func()
updateRegister func(newData any) updateRegister func(newData any)
@@ -229,6 +235,20 @@ func (o *Olm) registerAPICallbacks() {
logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID)
return o.SwitchOrg(req.OrgID) return o.SwitchOrg(req.OrgID)
}, },
// onMetadataChange
func(req api.MetadataChangeRequest) error {
logger.Info("Received change metadata request via API")
if req.Fingerprint != nil {
o.SetFingerprint(req.Fingerprint)
}
if req.Postures != nil {
o.SetPostures(req.Postures)
}
return nil
},
// onDisconnect // onDisconnect
func() error { func() error {
logger.Info("Processing disconnect request via API") logger.Info("Processing disconnect request via API")
@@ -336,12 +356,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
if o.stopRegister == nil { if o.stopRegister == nil {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{ o.stopRegister, o.updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]any{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"relay": !config.Holepunch, "relay": !config.Holepunch,
"olmVersion": o.olmConfig.Version, "olmVersion": o.olmConfig.Version,
"olmAgent": o.olmConfig.Agent, "olmAgent": o.olmConfig.Agent,
"orgId": config.OrgID, "orgId": config.OrgID,
"userToken": userToken, "userToken": userToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
}, 1*time.Second) }, 1*time.Second)
// Invoke onRegistered callback if configured // Invoke onRegistered callback if configured
@@ -404,6 +426,19 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
} }
}) })
fingerprint := config.InitialFingerprint
if fingerprint == nil {
fingerprint = make(map[string]any)
}
postures := config.InitialPostures
if postures == nil {
postures = make(map[string]any)
}
o.SetFingerprint(fingerprint)
o.SetPostures(postures)
// Connect to the WebSocket server // Connect to the WebSocket server
if err := olmClient.Connect(); err != nil { if err := olmClient.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err) logger.Error("Failed to connect to server: %v", err)
@@ -577,28 +612,16 @@ func (o *Olm) SwitchOrg(orgID string) error {
return nil return nil
} }
func (o *Olm) AddDevice(fd uint32) error { func (o *Olm) SetFingerprint(data map[string]any) {
if o.middleDev == nil { o.metaMu.Lock()
return fmt.Errorf("middle device is not initialized") defer o.metaMu.Unlock()
}
if o.tunnelConfig.MTU == 0 { o.fingerprint = data
return fmt.Errorf("tunnel MTU is not set") }
}
func (o *Olm) SetPostures(data map[string]any) {
tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) o.metaMu.Lock()
if err != nil { defer o.metaMu.Unlock()
return fmt.Errorf("failed to create TUN device from fd: %v", err)
} o.postures = data
// Update interface name if available
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
o.tunnelConfig.InterfaceName = realInterfaceName
}
// Replace the existing TUN device in the middle device with the new one
o.middleDev.AddDevice(tdev)
logger.Info("Added device from file descriptor %d", fd)
return nil
} }

View File

@@ -8,11 +8,12 @@ import (
"github.com/fosrl/olm/websocket" "github.com/fosrl/olm/websocket"
) )
func sendPing(olm *websocket.Client) error { func (o *Olm) sendPing(olm *websocket.Client) error {
logger.Debug("Sending ping message")
err := olm.SendMessage("olm/ping", map[string]any{ err := olm.SendMessage("olm/ping", map[string]any{
"timestamp": time.Now().Unix(), "timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken, "userToken": olm.GetConfig().UserToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
}) })
if err != nil { if err != nil {
logger.Error("Failed to send ping message: %v", err) logger.Error("Failed to send ping message: %v", err)
@@ -24,7 +25,7 @@ func sendPing(olm *websocket.Client) error {
func (o *Olm) keepSendingPing(olm *websocket.Client) { func (o *Olm) keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup // Send ping immediately on startup
if err := sendPing(olm); err != nil { if err := o.sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err) logger.Error("Failed to send initial ping: %v", err)
} else { } else {
logger.Info("Sent initial ping message") logger.Info("Sent initial ping message")
@@ -40,7 +41,7 @@ func (o *Olm) keepSendingPing(olm *websocket.Client) {
logger.Info("Stopping ping messages") logger.Info("Stopping ping messages")
return return
case <-ticker.C: case <-ticker.C:
if err := sendPing(olm); err != nil { if err := o.sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err) logger.Error("Failed to send periodic ping: %v", err)
} }
} }

View File

@@ -67,5 +67,8 @@ type TunnelConfig struct {
OverrideDNS bool OverrideDNS bool
TunnelDNS bool TunnelDNS bool
InitialFingerprint map[string]any
InitialPostures map[string]any
DisableRelay bool DisableRelay bool
} }