mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
refactor(olm): convert global state into an olm instance
Former-commit-id: b755f77d95
This commit is contained in:
committed by
Owen Schwartz
parent
8a788ef238
commit
15e96a779c
29
api/api.go
29
api/api.go
@@ -63,23 +63,26 @@ type StatusResponse struct {
|
|||||||
|
|
||||||
// API represents the HTTP server and its state
|
// API represents the HTTP server and its state
|
||||||
type API struct {
|
type API struct {
|
||||||
addr string
|
addr string
|
||||||
socketPath string
|
socketPath string
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
server *http.Server
|
server *http.Server
|
||||||
|
|
||||||
onConnect func(ConnectionRequest) error
|
onConnect func(ConnectionRequest) error
|
||||||
onSwitchOrg func(SwitchOrgRequest) error
|
onSwitchOrg func(SwitchOrgRequest) error
|
||||||
onDisconnect func() error
|
onDisconnect func() error
|
||||||
onExit func() error
|
onExit func() error
|
||||||
|
|
||||||
statusMu sync.RWMutex
|
statusMu sync.RWMutex
|
||||||
peerStatuses map[int]*PeerStatus
|
peerStatuses map[int]*PeerStatus
|
||||||
connectedAt time.Time
|
connectedAt time.Time
|
||||||
isConnected bool
|
isConnected bool
|
||||||
isRegistered bool
|
isRegistered bool
|
||||||
isTerminated bool
|
isTerminated bool
|
||||||
version string
|
|
||||||
agent string
|
version string
|
||||||
orgID string
|
agent string
|
||||||
|
orgID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||||
@@ -173,7 +176,7 @@ func (s *API) Stop() error {
|
|||||||
|
|
||||||
// Close the server first, which will also close the listener gracefully
|
// Close the server first, which will also close the listener gracefully
|
||||||
if s.server != nil {
|
if s.server != nil {
|
||||||
s.server.Close()
|
_ = s.server.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up socket file if using Unix socket
|
// Clean up socket file if using Unix socket
|
||||||
@@ -358,7 +361,7 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "connection request accepted",
|
"status": "connection request accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -406,7 +409,7 @@ func (s *API) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -423,7 +426,7 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response first
|
// Return a success response first
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "shutdown initiated",
|
"status": "shutdown initiated",
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -472,7 +475,7 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "org switch request accepted",
|
"status": "org switch request accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -506,7 +509,7 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Return a success response
|
// Return a success response
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
"status": "disconnect initiated",
|
"status": "disconnect initiated",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
12
main.go
12
main.go
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/updates"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/olm/olm"
|
olmpkg "github.com/fosrl/olm/olm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -210,7 +210,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new olm.Config struct and copy values from the main config
|
// Create a new olm.Config struct and copy values from the main config
|
||||||
olmConfig := olm.GlobalConfig{
|
olmConfig := olmpkg.OlmConfig{
|
||||||
LogLevel: config.LogLevel,
|
LogLevel: config.LogLevel,
|
||||||
EnableAPI: config.EnableAPI,
|
EnableAPI: config.EnableAPI,
|
||||||
HTTPAddr: config.HTTPAddr,
|
HTTPAddr: config.HTTPAddr,
|
||||||
@@ -222,13 +222,17 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
|
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.Init(ctx, olmConfig)
|
olm, err := olmpkg.Init(ctx, olmConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to initialize olm: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := olm.StartApi(); err != nil {
|
if err := olm.StartApi(); err != nil {
|
||||||
logger.Fatal("Failed to start API server: %v", err)
|
logger.Fatal("Failed to start API server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
if config.ID != "" && config.Secret != "" && config.Endpoint != "" {
|
||||||
tunnelConfig := olm.TunnelConfig{
|
tunnelConfig := olmpkg.TunnelConfig{
|
||||||
Endpoint: config.Endpoint,
|
Endpoint: config.Endpoint,
|
||||||
ID: config.ID,
|
ID: config.ID,
|
||||||
Secret: config.Secret,
|
Secret: config.Secret,
|
||||||
|
|||||||
223
olm/connect.go
Normal file
223
olm/connect.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/network"
|
||||||
|
olmDevice "github.com/fosrl/olm/device"
|
||||||
|
"github.com/fosrl/olm/dns"
|
||||||
|
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||||
|
"github.com/fosrl/olm/peers"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
|
var wgData WgData
|
||||||
|
|
||||||
|
if o.connected {
|
||||||
|
logger.Info("Already connected. Ignoring new connection request.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.stopRegister != nil {
|
||||||
|
o.stopRegister()
|
||||||
|
o.stopRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.updateRegister != nil {
|
||||||
|
o.updateRegister = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there is an existing tunnel then close it
|
||||||
|
if o.dev != nil {
|
||||||
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
|
o.dev.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &wgData); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.tdev, err = func() (tun.Device, error) {
|
||||||
|
if o.tunnelConfig.FileDescriptorTun != 0 {
|
||||||
|
return olmDevice.CreateTUNFromFD(o.tunnelConfig.FileDescriptorTun, o.tunnelConfig.MTU)
|
||||||
|
}
|
||||||
|
ifName := o.tunnelConfig.InterfaceName
|
||||||
|
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
||||||
|
ifName, err = network.FindUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tun.CreateTUN(ifName, o.tunnelConfig.MTU)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create TUN device: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if config.FileDescriptorTun == 0 {
|
||||||
|
if realInterfaceName, err2 := o.tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
|
||||||
|
o.tunnelConfig.InterfaceName = realInterfaceName
|
||||||
|
}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Wrap TUN device with packet filter for DNS proxy
|
||||||
|
o.middleDev = olmDevice.NewMiddleDevice(o.tdev)
|
||||||
|
|
||||||
|
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||||
|
// Use filtered device instead of raw TUN device
|
||||||
|
o.dev = device.NewDevice(o.middleDev, o.sharedBind, (*device.Logger)(wgLogger))
|
||||||
|
|
||||||
|
if o.tunnelConfig.EnableUAPI {
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
if o.tunnelConfig.FileDescriptorUAPI != 0 {
|
||||||
|
fd, err := strconv.ParseUint(fmt.Sprintf("%d", o.tunnelConfig.FileDescriptorUAPI), 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
|
||||||
|
}
|
||||||
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
}
|
||||||
|
return olmDevice.UapiOpen(o.tunnelConfig.InterfaceName)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("UAPI listen error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.uapiListener, err = olmDevice.UapiListen(o.tunnelConfig.InterfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := o.uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go o.dev.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = o.dev.Up(); err != nil {
|
||||||
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract interface IP (strip CIDR notation if present)
|
||||||
|
interfaceIP := wgData.TunnelIP
|
||||||
|
if strings.Contains(interfaceIP, "/") {
|
||||||
|
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start DNS proxy
|
||||||
|
o.dnsProxy, err = dns.NewDNSProxy(o.middleDev, o.tunnelConfig.MTU, wgData.UtilitySubnet, o.tunnelConfig.UpstreamDNS, o.tunnelConfig.TunnelDNS, interfaceIP)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create DNS proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = network.ConfigureInterface(o.tunnelConfig.InterfaceName, wgData.TunnelIP, o.tunnelConfig.MTU); err != nil {
|
||||||
|
logger.Error("Failed to o.tunnelConfigure interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.AddRoutes([]string{wgData.UtilitySubnet}, o.tunnelConfig.InterfaceName); err != nil { // also route the utility subnet
|
||||||
|
logger.Error("Failed to add route for utility subnet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create peer manager with integrated peer monitoring
|
||||||
|
o.peerManager = peers.NewPeerManager(peers.PeerManagerConfig{
|
||||||
|
Device: o.dev,
|
||||||
|
DNSProxy: o.dnsProxy,
|
||||||
|
InterfaceName: o.tunnelConfig.InterfaceName,
|
||||||
|
PrivateKey: o.privateKey,
|
||||||
|
MiddleDev: o.middleDev,
|
||||||
|
LocalIP: interfaceIP,
|
||||||
|
SharedBind: o.sharedBind,
|
||||||
|
WSClient: o.olmClient,
|
||||||
|
APIServer: o.apiServer,
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range wgData.Sites {
|
||||||
|
site := wgData.Sites[i]
|
||||||
|
var siteEndpoint string
|
||||||
|
// here we are going to take the relay endpoint if it exists which means we requested a relay for this peer
|
||||||
|
if site.RelayEndpoint != "" {
|
||||||
|
siteEndpoint = site.RelayEndpoint
|
||||||
|
} else {
|
||||||
|
siteEndpoint = site.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
o.apiServer.AddPeerStatus(site.SiteId, site.Name, false, 0, siteEndpoint, false)
|
||||||
|
|
||||||
|
if err := o.peerManager.AddPeer(site); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.peerManager.Start()
|
||||||
|
|
||||||
|
if err := o.dnsProxy.Start(); err != nil { // start DNS proxy first so there is no downtime
|
||||||
|
logger.Error("Failed to start DNS proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.tunnelConfig.OverrideDNS {
|
||||||
|
// Set up DNS override to use our DNS proxy
|
||||||
|
if err := dnsOverride.SetupDNSOverride(o.tunnelConfig.InterfaceName, o.dnsProxy.GetProxyIP()); err != nil {
|
||||||
|
logger.Error("Failed to setup DNS override: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
network.SetDNSServers([]string{o.dnsProxy.GetProxyIP().String()})
|
||||||
|
}
|
||||||
|
|
||||||
|
o.apiServer.SetRegistered(true)
|
||||||
|
|
||||||
|
o.connected = true
|
||||||
|
|
||||||
|
// Invoke onConnected callback if configured
|
||||||
|
if o.olmConfig.OnConnected != nil {
|
||||||
|
go o.olmConfig.OnConnected()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("WireGuard device created.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
||||||
|
logger.Info("Received terminate message")
|
||||||
|
o.apiServer.SetTerminated(true)
|
||||||
|
o.apiServer.SetConnectionStatus(false)
|
||||||
|
o.apiServer.SetRegistered(false)
|
||||||
|
o.apiServer.ClearPeerStatuses()
|
||||||
|
|
||||||
|
network.ClearNetworkSettings()
|
||||||
|
|
||||||
|
o.Close()
|
||||||
|
|
||||||
|
if o.olmConfig.OnTerminated != nil {
|
||||||
|
go o.olmConfig.OnTerminated()
|
||||||
|
}
|
||||||
|
}
|
||||||
197
olm/data.go
Normal file
197
olm/data.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/holepunch"
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/peers"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var addSubnetsData peers.PeerAdd
|
||||||
|
if err := json.Unmarshal(jsonData, &addSubnetsData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling add-remote-subnets data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := o.peerManager.GetPeer(addSubnetsData.SiteId); !exists {
|
||||||
|
logger.Debug("Peer %d not found for removing remote subnets and aliases", addSubnetsData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new subnets
|
||||||
|
for _, subnet := range addSubnetsData.RemoteSubnets {
|
||||||
|
if err := o.peerManager.AddRemoteSubnet(addSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new aliases
|
||||||
|
for _, alias := range addSubnetsData.Aliases {
|
||||||
|
if err := o.peerManager.AddAlias(addSubnetsData.SiteId, alias); err != nil {
|
||||||
|
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var removeSubnetsData peers.RemovePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &removeSubnetsData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling remove-remote-subnets data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := o.peerManager.GetPeer(removeSubnetsData.SiteId); !exists {
|
||||||
|
logger.Debug("Peer %d not found for removing remote subnets and aliases", removeSubnetsData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove subnets
|
||||||
|
for _, subnet := range removeSubnetsData.RemoteSubnets {
|
||||||
|
if err := o.peerManager.RemoveRemoteSubnet(removeSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove aliases
|
||||||
|
for _, alias := range removeSubnetsData.Aliases {
|
||||||
|
if err := o.peerManager.RemoveAlias(removeSubnetsData.SiteId, alias.Alias); err != nil {
|
||||||
|
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateSubnetsData peers.UpdatePeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &updateSubnetsData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling update-remote-subnets data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := o.peerManager.GetPeer(updateSubnetsData.SiteId); !exists {
|
||||||
|
logger.Debug("Peer %d not found for updating remote subnets and aliases", updateSubnetsData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new subnets BEFORE removing old ones to preserve shared subnets
|
||||||
|
// This ensures that if an old and new subnet are the same on different peers,
|
||||||
|
// the route won't be temporarily removed
|
||||||
|
for _, subnet := range updateSubnetsData.NewRemoteSubnets {
|
||||||
|
if err := o.peerManager.AddRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to add allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old subnets after new ones are added
|
||||||
|
for _, subnet := range updateSubnetsData.OldRemoteSubnets {
|
||||||
|
if err := o.peerManager.RemoveRemoteSubnet(updateSubnetsData.SiteId, subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove allowed IP %s: %v", subnet, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new aliases BEFORE removing old ones to preserve shared IP addresses
|
||||||
|
// This ensures that if an old and new alias share the same IP, the IP won't be
|
||||||
|
// temporarily removed from the allowed IPs list
|
||||||
|
for _, alias := range updateSubnetsData.NewAliases {
|
||||||
|
if err := o.peerManager.AddAlias(updateSubnetsData.SiteId, alias); err != nil {
|
||||||
|
logger.Error("Failed to add alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove old aliases after new ones are added
|
||||||
|
for _, alias := range updateSubnetsData.OldAliases {
|
||||||
|
if err := o.peerManager.RemoveAlias(updateSubnetsData.SiteId, alias.Alias); err != nil {
|
||||||
|
logger.Error("Failed to remove alias %s: %v", alias.Alias, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
ExitNode struct {
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
RelayPort uint16 `json:"relayPort"`
|
||||||
|
} `json:"exitNode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing peer from PeerManager
|
||||||
|
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||||
|
if exists {
|
||||||
|
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayPort := handshakeData.ExitNode.RelayPort
|
||||||
|
if relayPort == 0 {
|
||||||
|
relayPort = 21820 // default relay port
|
||||||
|
}
|
||||||
|
|
||||||
|
siteId := handshakeData.SiteId
|
||||||
|
exitNode := holepunch.ExitNode{
|
||||||
|
Endpoint: handshakeData.ExitNode.Endpoint,
|
||||||
|
RelayPort: relayPort,
|
||||||
|
PublicKey: handshakeData.ExitNode.PublicKey,
|
||||||
|
SiteIds: []int{siteId},
|
||||||
|
}
|
||||||
|
|
||||||
|
added := o.holePunchManager.AddExitNode(exitNode)
|
||||||
|
if added {
|
||||||
|
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||||
|
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
|
// Send handshake acknowledgment back to server with retry
|
||||||
|
o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": handshakeData.SiteId,
|
||||||
|
}, 1*time.Second)
|
||||||
|
|
||||||
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
|
}
|
||||||
976
olm/olm.go
976
olm/olm.go
File diff suppressed because it is too large
Load Diff
195
olm/peer.go
Normal file
195
olm/peer.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/util"
|
||||||
|
"github.com/fosrl/olm/peers"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
if o.stopPeerSend != nil {
|
||||||
|
o.stopPeerSend()
|
||||||
|
o.stopPeerSend = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var siteConfig peers.SiteConfig
|
||||||
|
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
||||||
|
logger.Error("Error unmarshaling add data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
|
||||||
|
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
||||||
|
logger.Error("Failed to add peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var removeData peers.PeerRemove
|
||||||
|
if err := json.Unmarshal(jsonData, &removeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling remove data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.peerManager.RemovePeer(removeData.SiteId); err != nil {
|
||||||
|
logger.Error("Failed to remove peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any exit nodes associated with this peer from hole punching
|
||||||
|
if o.holePunchManager != nil {
|
||||||
|
removed := o.holePunchManager.RemoveExitNodesByPeer(removeData.SiteId)
|
||||||
|
if removed > 0 {
|
||||||
|
logger.Info("Removed %d exit nodes associated with peer %d from hole punch rotation", removed, removeData.SiteId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateData peers.SiteConfig
|
||||||
|
if err := json.Unmarshal(jsonData, &updateData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling update data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing peer from PeerManager
|
||||||
|
existingPeer, exists := o.peerManager.GetPeer(updateData.SiteId)
|
||||||
|
if !exists {
|
||||||
|
logger.Warn("Peer with site ID %d not found", updateData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create updated site config by merging with existing data
|
||||||
|
siteConfig := existingPeer
|
||||||
|
|
||||||
|
if updateData.Endpoint != "" {
|
||||||
|
siteConfig.Endpoint = updateData.Endpoint
|
||||||
|
}
|
||||||
|
if updateData.RelayEndpoint != "" {
|
||||||
|
siteConfig.RelayEndpoint = updateData.RelayEndpoint
|
||||||
|
}
|
||||||
|
if updateData.PublicKey != "" {
|
||||||
|
siteConfig.PublicKey = updateData.PublicKey
|
||||||
|
}
|
||||||
|
if updateData.ServerIP != "" {
|
||||||
|
siteConfig.ServerIP = updateData.ServerIP
|
||||||
|
}
|
||||||
|
if updateData.ServerPort != 0 {
|
||||||
|
siteConfig.ServerPort = updateData.ServerPort
|
||||||
|
}
|
||||||
|
if updateData.RemoteSubnets != nil {
|
||||||
|
siteConfig.RemoteSubnets = updateData.RemoteSubnets
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.peerManager.UpdatePeer(siteConfig); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the endpoint changed, trigger holepunch to refresh NAT mappings
|
||||||
|
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
|
||||||
|
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
|
||||||
|
_ = o.holePunchManager.TriggerHolePunch()
|
||||||
|
o.holePunchManager.ResetInterval()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received relay-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if peerManager is still valid (may be nil during shutdown)
|
||||||
|
if o.peerManager == nil {
|
||||||
|
logger.Debug("Ignoring relay message: peerManager is nil (shutdown in progress)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayData peers.RelayPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update HTTP server to mark this peer as using relay
|
||||||
|
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
|
||||||
|
|
||||||
|
o.peerManager.RelayPeer(relayData.SiteId, primaryRelay, relayData.RelayPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received unrelay-peer message: %v", msg.Data)
|
||||||
|
|
||||||
|
// Check if peerManager is still valid (may be nil during shutdown)
|
||||||
|
if o.peerManager == nil {
|
||||||
|
logger.Debug("Ignoring unrelay message: peerManager is nil (shutdown in progress)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayData peers.UnRelayPeerData
|
||||||
|
if err := json.Unmarshal(jsonData, &relayData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling relay data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update HTTP server to mark this peer as using relay
|
||||||
|
o.apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
|
||||||
|
|
||||||
|
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
|
||||||
|
}
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func sendPing(olm *websocket.Client) error {
|
func sendPing(olm *websocket.Client) error {
|
||||||
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
err := olm.SendMessage("olm/ping", map[string]any{
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
"userToken": olm.GetConfig().UserToken,
|
"userToken": olm.GetConfig().UserToken,
|
||||||
})
|
})
|
||||||
@@ -21,7 +21,7 @@ func sendPing(olm *websocket.Client) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func keepSendingPing(olm *websocket.Client) {
|
func (o *Olm) keepSendingPing(olm *websocket.Client) {
|
||||||
// Send ping immediately on startup
|
// Send ping immediately on startup
|
||||||
if err := sendPing(olm); err != nil {
|
if err := sendPing(olm); err != nil {
|
||||||
logger.Error("Failed to send initial ping: %v", err)
|
logger.Error("Failed to send initial ping: %v", err)
|
||||||
@@ -35,7 +35,7 @@ func keepSendingPing(olm *websocket.Client) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-stopPing:
|
case <-o.stopPing:
|
||||||
logger.Info("Stopping ping messages")
|
logger.Info("Stopping ping messages")
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
@@ -12,7 +12,7 @@ type WgData struct {
|
|||||||
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
||||||
}
|
}
|
||||||
|
|
||||||
type GlobalConfig struct {
|
type OlmConfig struct {
|
||||||
// Logging
|
// Logging
|
||||||
LogLevel string
|
LogLevel string
|
||||||
LogFilePath string
|
LogFilePath string
|
||||||
|
|||||||
Reference in New Issue
Block a user