Compare commits

...

23 Commits

Author SHA1 Message Date
Owen
7fc09f8ed1 Fix windows build
Former-commit-id: 6af69cdcd6
2025-11-08 20:39:36 -08:00
Owen
079843602c Dont close and comment out dont create
Former-commit-id: 9b74bcfb81
2025-11-08 17:42:19 -08:00
Owen
70bf22c354 Remove binaries
Former-commit-id: 3398d2ab7e
2025-11-08 17:38:46 -08:00
Owen
3d891cfa97 Remove do not create client for now
Because its always created when the user joins the org


Former-commit-id: 8ebc678edb
2025-11-08 16:54:35 -08:00
Owen
78e3bb374a Split out hp
Former-commit-id: 29ed4fefbf
2025-11-07 21:59:07 -08:00
Owen
a61c7ca1ee Custom bind?
Former-commit-id: 6d8e298ebc
2025-11-07 21:39:28 -08:00
Owen
7696ba2e36 Add DoNotCreateNewClient
Former-commit-id: aedebb5579
2025-11-07 15:26:45 -08:00
Owen
235877c379 Add optional user token to validate
Former-commit-id: 5734684a21
2025-11-07 14:52:54 -08:00
Owen
befab0f8d1 Fix passing original arguments
Former-commit-id: 7e5b740514
2025-11-07 14:33:52 -08:00
Owen
914d080a57 Connecting disconnecting working
Former-commit-id: 553010f2ea
2025-11-07 14:31:13 -08:00
Owen
a274b4b38f Starting and stopping working
Former-commit-id: f23f2fb9aa
2025-11-07 14:20:36 -08:00
Owen
ce3c585514 Allow connecting and disconnecting
Former-commit-id: 596c4aa0da
2025-11-07 14:07:44 -08:00
Owen
963d8abad5 Add org id in the status
Former-commit-id: da1e4911bd
2025-11-03 20:54:55 -08:00
Owen
38eb56381f Update switching orgs
Former-commit-id: 690b133c7b
2025-11-03 20:33:06 -08:00
Owen
43b3822090 Allow pasing orgId to select org to connect
Former-commit-id: 46a4847cee
2025-11-03 16:54:38 -08:00
Owen
b0fb370c4d Remove status
Former-commit-id: 352ac8def6
2025-11-03 15:29:18 -08:00
Owen
99328ee76f Add registered to api
Former-commit-id: 9c496f7ca7
2025-11-03 15:16:12 -08:00
Owen
36fc3ea253 Add exit call
Former-commit-id: 4a89915826
2025-11-03 14:15:16 -08:00
Owen
a7979259f3 Make api availble over socket
Former-commit-id: e464af5302
2025-11-02 18:56:09 -08:00
Owen
ea6fa72bc0 Copy in config
Former-commit-id: 3505549331
2025-11-02 12:09:39 -08:00
Owen
f9adde6b1d Rename to run
Former-commit-id: 6f7e866e93
2025-11-01 18:39:53 -07:00
Owen
ba25586646 Import submodule
Former-commit-id: eaf94e6855
2025-11-01 18:37:53 -07:00
Owen
952ab63e8d Package?
Former-commit-id: 218e4f88bc
2025-11-01 18:34:00 -07:00
19 changed files with 3263 additions and 1300 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,2 @@
olm
.DS_Store
bin/

View File

@@ -10,7 +10,7 @@ docker-build-release:
docker buildx build --platform linux/arm/v7,linux/arm64,linux/amd64 -t fosrl/olm:$(tag) -f Dockerfile --push .
local:
CGO_ENABLED=0 go build -o olm
CGO_ENABLED=0 go build -o bin/olm
build:
docker build -t fosrl/olm:latest .

411
api/api.go Normal file
View File

@@ -0,0 +1,411 @@
package api
import (
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"orgId"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
PeerIP string `json:"peerAddress,omitempty"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Connected bool `json:"connected"`
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// API represents the HTTP server and its state
type API struct {
addr string
socketPath string
listener net.Listener
server *http.Server
connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
isRegistered bool
tunnelIP string
version string
orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
func NewAPI(addr string) *API {
s := &API{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe
func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *API) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{
Handler: mux,
}
var err error
if s.socketPath != "" {
// Use platform-specific socket listener
s.listener, err = createSocketListener(s.socketPath)
if err != nil {
return fmt.Errorf("failed to create socket listener: %w", err)
}
logger.Info("Starting HTTP server on socket %s", s.socketPath)
} else {
// Use TCP listener
s.listener, err = net.Listen("tcp", s.addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
logger.Info("Starting HTTP server on %s", s.addr)
}
go func() {
if err := s.server.Serve(s.listener); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *API) Stop() error {
logger.Info("Stopping api server")
// Close the server first, which will also close the listener gracefully
if s.server != nil {
s.server.Close()
}
// Clean up socket file if using Unix socket
if s.socketPath != "" {
cleanupSocket(s.socketPath)
}
return nil
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// GetSwitchOrgChannel returns the channel for receiving org switch requests
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
return s.switchOrgChan
}
// GetShutdownChannel returns the channel for receiving shutdown requests
func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// SetConnectionStatus sets the overall connection status
func (s *API) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
func (s *API) SetRegistered(registered bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isRegistered = registered
}
// SetTunnelIP sets the tunnel IP address
func (s *API) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *API) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// SetOrgID sets the organization ID
func (s *API) SetOrgID(orgID string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.orgID = orgID
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// handleConnect handles the /connect endpoint
func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already connected, reject new connection requests
s.statusMu.RLock()
alreadyConnected := s.isConnected
s.statusMu.RUnlock()
if alreadyConnected {
http.Error(w, "Already connected to a server. Disconnect first before connecting again.", http.StatusConflict)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// handleExit handles the /exit endpoint
func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received exit request via API")
// Send shutdown signal
select {
case s.shutdownChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "shutdown initiated",
})
}
// handleSwitchOrg handles the /switch-org endpoint
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req SwitchOrgRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.OrgID == "" {
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
return
}
logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Send the request to the main goroutine
select {
case s.switchOrgChan <- req:
// Signal sent successfully
default:
// Channel already has a pending request
http.Error(w, "Org switch already in progress", http.StatusConflict)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// if we are already disconnected, reject new disconnect requests
s.statusMu.RLock()
alreadyDisconnected := !s.isConnected
s.statusMu.RUnlock()
if alreadyDisconnected {
http.Error(w, "Not currently connected to a server.", http.StatusConflict)
return
}
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}

50
api/api_unix.go Normal file
View File

@@ -0,0 +1,50 @@
//go:build !windows
// +build !windows
package api
import (
"fmt"
"net"
"os"
"path/filepath"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Unix domain socket listener
func createSocketListener(socketPath string) (net.Listener, error) {
// Ensure the directory exists
dir := filepath.Dir(socketPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create socket directory: %w", err)
}
// Remove existing socket file if it exists
if err := os.RemoveAll(socketPath); err != nil {
return nil, fmt.Errorf("failed to remove existing socket: %w", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on Unix socket: %w", err)
}
// Set socket permissions to allow access
if err := os.Chmod(socketPath, 0666); err != nil {
listener.Close()
return nil, fmt.Errorf("failed to set socket permissions: %w", err)
}
logger.Debug("Created Unix socket at %s", socketPath)
return listener, nil
}
// cleanupSocket removes the Unix socket file
func cleanupSocket(socketPath string) {
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
logger.Error("Failed to remove socket file %s: %v", socketPath, err)
} else {
logger.Debug("Removed Unix socket at %s", socketPath)
}
}

41
api/api_windows.go Normal file
View File

@@ -0,0 +1,41 @@
//go:build windows
// +build windows
package api
import (
"fmt"
"net"
"github.com/Microsoft/go-winio"
"github.com/fosrl/newt/logger"
)
// createSocketListener creates a Windows named pipe listener
func createSocketListener(pipePath string) (net.Listener, error) {
// Ensure the pipe path has the correct format
if pipePath[0] != '\\' {
pipePath = `\\.\pipe\` + pipePath
}
// Create a pipe configuration that allows everyone to write
config := &winio.PipeConfig{
// Set security descriptor to allow everyone full access
// This SDDL string grants full access to Everyone (WD) and to the current owner (OW)
SecurityDescriptor: "D:(A;;GA;;;WD)(A;;GA;;;OW)",
}
// Create a named pipe listener using go-winio with the configuration
listener, err := winio.ListenPipe(pipePath, config)
if err != nil {
return nil, fmt.Errorf("failed to listen on named pipe: %w", err)
}
logger.Debug("Created named pipe at %s with write access for everyone", pipePath)
return listener, nil
}
// cleanupSocket is a no-op on Windows as named pipes are automatically cleaned up
func cleanupSocket(pipePath string) {
logger.Debug("Named pipe %s will be automatically cleaned up", pipePath)
}

378
bind/shared_bind.go Normal file
View File

@@ -0,0 +1,378 @@
//go:build !js
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// Endpoint represents a network endpoint for the SharedBind
type Endpoint struct {
AddrPort netip.AddrPort
}
// ClearSrc implements the wgConn.Endpoint interface
func (e *Endpoint) ClearSrc() {}
// DstIP implements the wgConn.Endpoint interface
func (e *Endpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// SrcIP implements the wgConn.Endpoint interface
func (e *Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
// DstToBytes implements the wgConn.Endpoint interface
func (e *Endpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
// DstToString implements the wgConn.Endpoint interface
func (e *Endpoint) DstToString() string {
return e.AddrPort.String()
}
// SrcToString implements the wgConn.Endpoint interface
func (e *Endpoint) SrcToString() string {
return ""
}
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
// Reference counting to prevent closing while in use
refCount atomic.Int32
closed atomic.Bool
// Channels for receiving data
recvFuncs []wgConn.ReceiveFunc
// Port binding information
port uint16
}
// New creates a new SharedBind from an existing UDP connection.
// The SharedBind takes ownership of the connection and will close it
// when all references are released.
func New(udpConn *net.UDPConn) (*SharedBind, error) {
if udpConn == nil {
return nil, fmt.Errorf("udpConn cannot be nil")
}
bind := &SharedBind{
udpConn: udpConn,
}
// Initialize reference count to 1 (the creator holds the first reference)
bind.refCount.Store(1)
// Get the local port
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
bind.port = uint16(addr.Port)
}
return bind, nil
}
// AddRef increments the reference count. Call this when sharing
// the bind with another component.
func (b *SharedBind) AddRef() {
newCount := b.refCount.Add(1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
}
// Release decrements the reference count. When it reaches zero,
// the underlying UDP connection is closed.
func (b *SharedBind) Release() error {
newCount := b.refCount.Add(-1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
if newCount < 0 {
// This should never happen with proper usage
b.refCount.Store(0)
return fmt.Errorf("SharedBind reference count went negative")
}
if newCount == 0 {
return b.closeConnection()
}
return nil
}
// closeConnection actually closes the UDP connection
func (b *SharedBind) closeConnection() error {
if !b.closed.CompareAndSwap(false, true) {
// Already closed
return nil
}
b.mu.Lock()
defer b.mu.Unlock()
var err error
if b.udpConn != nil {
err = b.udpConn.Close()
b.udpConn = nil
}
b.ipv4PC = nil
b.ipv6PC = nil
return err
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
b.mu.RLock()
defer b.mu.RUnlock()
return b.udpConn
}
// GetRefCount returns the current reference count (for debugging)
func (b *SharedBind) GetRefCount() int32 {
return b.refCount.Load()
}
// IsClosed returns whether the bind is closed
func (b *SharedBind) IsClosed() bool {
return b.closed.Load()
}
// WriteToUDP writes data to a specific UDP address.
// This is thread-safe and can be used by hole punch senders.
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
return conn.WriteToUDP(data, addr)
}
// Close implements the WireGuard Bind interface.
// It decrements the reference count and closes the connection if no references remain.
func (b *SharedBind) Close() error {
return b.Release()
}
// Open implements the WireGuard Bind interface.
// Since the connection is already open, this just sets up the receive functions.
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
if b.closed.Load() {
return nil, 0, net.ErrClosed
}
b.mu.Lock()
defer b.mu.Unlock()
if b.udpConn == nil {
return nil, 0, net.ErrClosed
}
// Set up IPv4 and IPv6 packet connections for advanced features
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
}
// Create receive functions
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
// Add IPv4 receive function
if b.ipv4PC != nil || runtime.GOOS != "linux" {
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
}
// Add IPv6 receive function if needed
// For now, we focus on IPv4 for hole punching use case
b.recvFuncs = recvFuncs
return recvFuncs, b.port, nil
}
// makeReceiveIPv4 creates a receive function for IPv4 packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
// Fallback to simple read for other platforms
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}
// receiveIPv4Batch uses batch reading for better performance on Linux
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
// Create messages for batch reading
msgs := make([]ipv4.Message, len(bufs))
for i := range bufs {
msgs[i].Buffers = [][]byte{bufs[i]}
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
}
numMsgs, err := pc.ReadBatch(msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
sizes[i] = msgs[i].N
if sizes[i] == 0 {
continue
}
if msgs[i].Addr != nil {
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
addrPort := udpAddr.AddrPort()
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
}
}
return numMsgs, nil
}
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
n, addr, err := conn.ReadFromUDP(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
if addr != nil {
addrPort := addr.AddrPort()
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
return 1, nil
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddr *net.UDPAddr
// Try to cast to StdNetEndpoint first
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
var addr netip.Addr
var port uint16
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
addr, _ = netip.AddrFromSlice(dstBytes[:16])
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
} else { // IPv4
addr, _ = netip.AddrFromSlice(dstBytes[:4])
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
}
if addr.IsValid() {
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
}
}
}
if destAddr == nil {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// SetMark implements the WireGuard Bind interface.
// It's a no-op for this implementation.
func (b *SharedBind) SetMark(mark uint32) error {
// Not implemented for this use case
return nil
}
// BatchSize returns the preferred batch size for sending packets.
func (b *SharedBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return wgConn.IdealBatchSize
}
return 1
}
// ParseEndpoint creates a new endpoint from a string address.
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
addrPort, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
}

424
bind/shared_bind_test.go Normal file
View File

@@ -0,0 +1,424 @@
//go:build !js
package bind
import (
"net"
"net/netip"
"sync"
"testing"
"time"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// TestSharedBindCreation tests basic creation and initialization
func TestSharedBindCreation(t *testing.T) {
// Create a UDP connection
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
defer udpConn.Close()
// Create SharedBind
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
if bind == nil {
t.Fatal("SharedBind is nil")
}
// Verify initial reference count
if bind.refCount.Load() != 1 {
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
}
// Clean up
if err := bind.Close(); err != nil {
t.Errorf("Failed to close SharedBind: %v", err)
}
}
// TestSharedBindReferenceCount tests reference counting
func TestSharedBindReferenceCount(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add references
bind.AddRef()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
}
bind.AddRef()
if bind.refCount.Load() != 3 {
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
}
// Release references
bind.Release()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
}
bind.Release()
bind.Release() // This should close the connection
if !bind.closed.Load() {
t.Error("Expected bind to be closed after all references released")
}
}
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
func TestSharedBindWriteToUDP(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Send data
testData := []byte("Hello, SharedBind!")
n, err := senderBind.WriteToUDP(testData, receiverAddr)
if err != nil {
t.Fatalf("WriteToUDP failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err = receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindConcurrentWrites tests thread-safety
func TestSharedBindConcurrentWrites(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Launch concurrent writes
numGoroutines := 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
data := []byte{byte(id)}
_, err := senderBind.WriteToUDP(data, receiverAddr)
if err != nil {
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
}
}(i)
}
wg.Wait()
}
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
func TestSharedBindWireGuardInterface(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
// Test Open
recvFuncs, port, err := bind.Open(0)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
if len(recvFuncs) == 0 {
t.Error("Expected at least one receive function")
}
if port == 0 {
t.Error("Expected non-zero port")
}
// Test SetMark (should be a no-op)
if err := bind.SetMark(0); err != nil {
t.Errorf("SetMark failed: %v", err)
}
// Test BatchSize
batchSize := bind.BatchSize()
if batchSize <= 0 {
t.Error("Expected positive batch size")
}
}
// TestSharedBindSend tests the Send method with WireGuard endpoints
func TestSharedBindSend(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Create an endpoint
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
// Send data
testData := []byte("WireGuard packet")
bufs := [][]byte{testData}
err = senderBind.Send(bufs, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err := receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
func TestSharedBindMultipleUsers(t *testing.T) {
// Create shared bind
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
sharedBind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add reference for hole punch sender
sharedBind.AddRef()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
var wg sync.WaitGroup
// Simulate WireGuard using the bind
wg.Add(1)
go func() {
defer wg.Done()
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
for i := 0; i < 10; i++ {
data := []byte("WireGuard packet")
bufs := [][]byte{data}
if err := sharedBind.Send(bufs, endpoint); err != nil {
t.Errorf("WireGuard Send failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
// Simulate hole punch sender using the bind
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
data := []byte("Hole punch packet")
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
t.Errorf("Hole punch WriteToUDP failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
wg.Wait()
// Release the hole punch reference
sharedBind.Release()
// Close WireGuard's reference (should close the connection)
sharedBind.Close()
if !sharedBind.closed.Load() {
t.Error("Expected bind to be closed after all users released it")
}
}
// TestEndpoint tests the Endpoint implementation
func TestEndpoint(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
addrPort := netip.AddrPortFrom(addr, 51820)
ep := &Endpoint{AddrPort: addrPort}
// Test DstIP
if ep.DstIP() != addr {
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
}
// Test DstToString
expected := "192.168.1.1:51820"
if ep.DstToString() != expected {
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
}
// Test DstToBytes
bytes := ep.DstToBytes()
if len(bytes) == 0 {
t.Error("Expected DstToBytes to return non-empty slice")
}
// Test SrcIP (should be zero)
if ep.SrcIP().IsValid() {
t.Error("Expected SrcIP to be invalid")
}
// Test ClearSrc (should not panic)
ep.ClearSrc()
}
// TestParseEndpoint tests the ParseEndpoint method
func TestParseEndpoint(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
tests := []struct {
name string
input string
wantErr bool
checkAddr func(*testing.T, wgConn.Endpoint)
}{
{
name: "valid IPv4",
input: "192.168.1.1:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "192.168.1.1:51820" {
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
}
},
},
{
name: "valid IPv6",
input: "[::1]:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "[::1]:51820" {
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
}
},
},
{
name: "invalid - missing port",
input: "192.168.1.1",
wantErr: true,
},
{
name: "invalid - bad format",
input: "not-an-address",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ep, err := bind.ParseEndpoint(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && tt.checkAddr != nil {
tt.checkAddr(t, ep)
}
})
}
}

124
config.go
View File

@@ -14,9 +14,11 @@ import (
// OlmConfig holds all configuration options for the Olm client
type OlmConfig struct {
// Connection settings
Endpoint string `json:"endpoint"`
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
ID string `json:"id"`
Secret string `json:"secret"`
OrgID string `json:"org"`
UserToken string `json:"userToken"`
// Network settings
MTU int `json:"mtu"`
@@ -27,8 +29,9 @@ type OlmConfig struct {
LogLevel string `json:"logLevel"`
// HTTP server
EnableHTTP bool `json:"enableHttp"`
EnableAPI bool `json:"enableApi"`
HTTPAddr string `json:"httpAddr"`
SocketPath string `json:"socketPath"`
// Ping settings
PingInterval string `json:"pingInterval"`
@@ -37,6 +40,7 @@ type OlmConfig struct {
// Advanced
Holepunch bool `json:"holepunch"`
TlsClientCert string `json:"tlsClientCert"`
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
// Parsed values (not in JSON)
PingIntervalDuration time.Duration `json:"-"`
@@ -44,6 +48,8 @@ type OlmConfig struct {
// Source tracking (not in JSON)
sources map[string]string `json:"-"`
Version string
}
// ConfigSource tracks where each config value came from
@@ -58,17 +64,27 @@ const (
// DefaultConfig returns a config with default values
func DefaultConfig() *OlmConfig {
// Set OS-specific socket path
var socketPath string
switch runtime.GOOS {
case "windows":
socketPath = "olm"
default: // darwin, linux, and others
socketPath = "/var/run/olm.sock"
}
config := &OlmConfig{
MTU: 1280,
DNS: "8.8.8.8",
LogLevel: "INFO",
InterfaceName: "olm",
EnableHTTP: false,
HTTPAddr: ":9452",
EnableAPI: false,
SocketPath: socketPath,
PingInterval: "3s",
PingTimeout: "5s",
Holepunch: false,
sources: make(map[string]string),
// DoNotCreateNewClient: false,
sources: make(map[string]string),
}
// Track default sources
@@ -76,11 +92,13 @@ func DefaultConfig() *OlmConfig {
config.sources["dns"] = string(SourceDefault)
config.sources["logLevel"] = string(SourceDefault)
config.sources["interface"] = string(SourceDefault)
config.sources["enableHttp"] = string(SourceDefault)
config.sources["enableApi"] = string(SourceDefault)
config.sources["httpAddr"] = string(SourceDefault)
config.sources["socketPath"] = string(SourceDefault)
config.sources["pingInterval"] = string(SourceDefault)
config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault)
// config.sources["doNotCreateNewClient"] = string(SourceDefault)
return config
}
@@ -175,6 +193,14 @@ func loadConfigFromEnv(config *OlmConfig) {
config.Secret = val
config.sources["secret"] = string(SourceEnv)
}
if val := os.Getenv("ORG"); val != "" {
config.OrgID = val
config.sources["org"] = string(SourceEnv)
}
if val := os.Getenv("USER_TOKEN"); val != "" {
config.UserToken = val
config.sources["userToken"] = string(SourceEnv)
}
if val := os.Getenv("MTU"); val != "" {
if mtu, err := strconv.Atoi(val); err == nil {
config.MTU = mtu
@@ -207,14 +233,22 @@ func loadConfigFromEnv(config *OlmConfig) {
config.PingTimeout = val
config.sources["pingTimeout"] = string(SourceEnv)
}
if val := os.Getenv("ENABLE_HTTP"); val == "true" {
config.EnableHTTP = true
config.sources["enableHttp"] = string(SourceEnv)
if val := os.Getenv("ENABLE_API"); val == "true" {
config.EnableAPI = true
config.sources["enableApi"] = string(SourceEnv)
}
if val := os.Getenv("SOCKET_PATH"); val != "" {
config.SocketPath = val
config.sources["socketPath"] = string(SourceEnv)
}
if val := os.Getenv("HOLEPUNCH"); val == "true" {
config.Holepunch = true
config.sources["holepunch"] = string(SourceEnv)
}
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
// config.DoNotCreateNewClient = true
// config.sources["doNotCreateNewClient"] = string(SourceEnv)
// }
}
// loadConfigFromCLI loads configuration from command-line arguments
@@ -226,30 +260,38 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"endpoint": config.Endpoint,
"id": config.ID,
"secret": config.Secret,
"org": config.OrgID,
"userToken": config.UserToken,
"mtu": config.MTU,
"dns": config.DNS,
"logLevel": config.LogLevel,
"interface": config.InterfaceName,
"httpAddr": config.HTTPAddr,
"socketPath": config.SocketPath,
"pingInterval": config.PingInterval,
"pingTimeout": config.PingTimeout,
"enableHttp": config.EnableHTTP,
"enableApi": config.EnableAPI,
"holepunch": config.Holepunch,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}
// Define flags
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
serviceFlags.StringVar(&config.SocketPath, "socket-path", config.SocketPath, "Unix socket path (or named pipe on Windows)")
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests")
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version")
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
@@ -269,6 +311,12 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.Secret != origValues["secret"].(string) {
config.sources["secret"] = string(SourceCLI)
}
if config.OrgID != origValues["org"].(string) {
config.sources["org"] = string(SourceCLI)
}
if config.UserToken != origValues["userToken"].(string) {
config.sources["userToken"] = string(SourceCLI)
}
if config.MTU != origValues["mtu"].(int) {
config.sources["mtu"] = string(SourceCLI)
}
@@ -284,18 +332,24 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.HTTPAddr != origValues["httpAddr"].(string) {
config.sources["httpAddr"] = string(SourceCLI)
}
if config.SocketPath != origValues["socketPath"].(string) {
config.sources["socketPath"] = string(SourceCLI)
}
if config.PingInterval != origValues["pingInterval"].(string) {
config.sources["pingInterval"] = string(SourceCLI)
}
if config.PingTimeout != origValues["pingTimeout"].(string) {
config.sources["pingTimeout"] = string(SourceCLI)
}
if config.EnableHTTP != origValues["enableHttp"].(bool) {
config.sources["enableHttp"] = string(SourceCLI)
if config.EnableAPI != origValues["enableApi"].(bool) {
config.sources["enableApi"] = string(SourceCLI)
}
if config.Holepunch != origValues["holepunch"].(bool) {
config.sources["holepunch"] = string(SourceCLI)
}
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
// config.sources["doNotCreateNewClient"] = string(SourceCLI)
// }
return *version, *showConfig, nil
}
@@ -348,6 +402,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.Secret = src.Secret
dest.sources["secret"] = string(SourceFile)
}
if src.OrgID != "" {
dest.OrgID = src.OrgID
dest.sources["org"] = string(SourceFile)
}
if src.UserToken != "" {
dest.UserToken = src.UserToken
dest.sources["userToken"] = string(SourceFile)
}
if src.MTU != 0 && src.MTU != 1280 {
dest.MTU = src.MTU
dest.sources["mtu"] = string(SourceFile)
@@ -368,6 +430,14 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.HTTPAddr = src.HTTPAddr
dest.sources["httpAddr"] = string(SourceFile)
}
if src.SocketPath != "" {
// Check if it's not the default for any OS
isDefault := src.SocketPath == "/var/run/olm.sock" || src.SocketPath == "olm"
if !isDefault {
dest.SocketPath = src.SocketPath
dest.sources["socketPath"] = string(SourceFile)
}
}
if src.PingInterval != "" && src.PingInterval != "3s" {
dest.PingInterval = src.PingInterval
dest.sources["pingInterval"] = string(SourceFile)
@@ -381,14 +451,18 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.sources["tlsClientCert"] = string(SourceFile)
}
// For booleans, we always take the source value if explicitly set
if src.EnableHTTP {
dest.EnableHTTP = src.EnableHTTP
dest.sources["enableHttp"] = string(SourceFile)
if src.EnableAPI {
dest.EnableAPI = src.EnableAPI
dest.sources["enableApi"] = string(SourceFile)
}
if src.Holepunch {
dest.Holepunch = src.Holepunch
dest.sources["holepunch"] = string(SourceFile)
}
// if src.DoNotCreateNewClient {
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient
// dest.sources["doNotCreateNewClient"] = string(SourceFile)
// }
}
// SaveConfig saves the current configuration to the config file
@@ -445,6 +519,8 @@ func (c *OlmConfig) ShowConfig() {
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
// Network settings
fmt.Println("\nNetwork:")
@@ -456,10 +532,11 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nLogging:")
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
// HTTP server
fmt.Println("\nHTTP Server:")
fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp"))
// API server
fmt.Println("\nAPI Server:")
fmt.Printf(" enable-api = %v [%s]\n", c.EnableAPI, getSource("enableApi"))
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
fmt.Printf(" socket-path = %s [%s]\n", c.SocketPath, getSource("socketPath"))
// Timing
fmt.Println("\nTiming:")
@@ -468,9 +545,10 @@ func (c *OlmConfig) ShowConfig() {
// Advanced
fmt.Println("\nAdvanced:")
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
if c.TlsClientCert != "" {
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
}
// Source legend

523
diff Normal file
View File

@@ -0,0 +1,523 @@
diff --git a/api/api.go b/api/api.go
index dd07751..0d2e4ef 100644
--- a/api/api.go
+++ b/api/api.go
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
Endpoint string `json:"endpoint"`
}
+// SwitchOrgRequest defines the structure for switching organizations
+type SwitchOrgRequest struct {
+ OrgID string `json:"orgId"`
+}
+
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
@@ -35,6 +40,7 @@ type StatusResponse struct {
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
+ OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
@@ -46,6 +52,7 @@ type API struct {
server *http.Server
connectionChan chan ConnectionRequest
shutdownChan chan struct{}
+ switchOrgChan chan SwitchOrgRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
@@ -53,6 +60,7 @@ type API struct {
isRegistered bool
tunnelIP string
version string
+ orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -85,6 +95,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/exit", s.handleExit)
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
s.server = &http.Server{
Handler: mux,
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
+ return s.switchOrgChan
+}
+
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
s.version = version
}
+// SetOrgID sets the org ID
+func (s *API) SetOrgID(orgID string) {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+ s.orgID = orgID
+}
+
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
+ OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
"status": "shutdown initiated",
})
}
+
+// handleSwitchOrg handles the /switch-org endpoint
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req SwitchOrgRequest
+ decoder := json.NewDecoder(r.Body)
+ if err := decoder.Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.OrgID == "" {
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
+ return
+ }
+
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
+
+ // Send the request to the main goroutine
+ select {
+ case s.switchOrgChan <- req:
+ // Signal sent successfully
+ default:
+ // Channel already has a signal, don't block
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
+ return
+ }
+
+ // Return a success response
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusAccepted)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "org switch initiated",
+ "orgId": req.OrgID,
+ })
+}
diff --git a/olm/olm.go b/olm/olm.go
index 78080c4..5e292d6 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -58,6 +58,58 @@ type Config struct {
OrgID string
}
+// tunnelState holds all the active tunnel resources that need cleanup
+type tunnelState struct {
+ dev *device.Device
+ tdev tun.Device
+ uapiListener net.Listener
+ peerMonitor *peermonitor.PeerMonitor
+ stopRegister func()
+ connected bool
+}
+
+// teardownTunnel cleans up all tunnel resources
+func teardownTunnel(state *tunnelState) {
+ if state == nil {
+ return
+ }
+
+ logger.Info("Tearing down tunnel...")
+
+ // Stop registration messages
+ if state.stopRegister != nil {
+ state.stopRegister()
+ state.stopRegister = nil
+ }
+
+ // Stop peer monitor
+ if state.peerMonitor != nil {
+ state.peerMonitor.Stop()
+ state.peerMonitor = nil
+ }
+
+ // Close UAPI listener
+ if state.uapiListener != nil {
+ state.uapiListener.Close()
+ state.uapiListener = nil
+ }
+
+ // Close WireGuard device
+ if state.dev != nil {
+ state.dev.Close()
+ state.dev = nil
+ }
+
+ // Close TUN device
+ if state.tdev != nil {
+ state.tdev.Close()
+ state.tdev = nil
+ }
+
+ state.connected = false
+ logger.Info("Tunnel teardown complete")
+}
+
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
- connected bool
- dev *device.Device
wgData WgData
holePunchData HolePunchData
- uapiListener net.Listener
- tdev tun.Device
+ orgID = config.OrgID
)
+ // Tunnel state that can be torn down and recreated
+ tunnel := &tunnelState{}
+
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
+ apiServer.SetOrgID(orgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
- if connected {
+ if tunnel.connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
close(stopHolepunch)
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
- if dev != nil {
+ if tunnel.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
- dev.Close()
+ tunnel.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
return
}
- tdev, err = func() (tun.Device, error) {
+ tunnel.tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
return
}
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
return
}
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
go func() {
for {
- conn, err := uapiListener.Accept()
+ conn, err := tunnel.uapiListener.Accept()
if err != nil {
return
}
- go dev.IpcHandle(conn)
+ go tunnel.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
- if err = dev.Up(); err != nil {
+ if err = tunnel.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
- peerMonitor = peermonitor.NewPeerMonitor(
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
},
fixKey(privateKey.String()),
olm,
- dev,
+ tunnel.dev,
doHolepunch,
)
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
logger.Info("Configured peer %s", site.PublicKey)
}
- peerMonitor.Start()
+ tunnel.peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true)
}
- connected = true
+ tunnel.connected = true
logger.Info("WireGuard device created.")
})
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
}
// Update the peer in WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
}
// Add the peer to WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
}
// Remove the peer from WireGuard
- if dev != nil {
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
+ if tunnel.dev != nil {
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetConnectionStatus(true)
}
- if connected {
+ if tunnel.connected {
logger.Debug("Already connected, skipping registration")
return nil
}
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
- "orgId": config.OrgID,
+ "orgId": orgID,
}, 1*time.Second)
go keepSendingPing(olm)
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
}
defer olm.Close()
+ // Listen for org switch requests from the API (after olm is created)
+ if apiServer != nil {
+ go func() {
+ for req := range apiServer.GetSwitchOrgChannel() {
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
+
+ // Update the orgId
+ orgID = req.OrgID
+
+ // Teardown existing tunnel
+ teardownTunnel(tunnel)
+
+ // Reset tunnel state
+ tunnel = &tunnelState{}
+
+ // Stop holepunch
+ select {
+ case <-stopHolepunch:
+ // Channel already closed
+ default:
+ close(stopHolepunch)
+ }
+ stopHolepunch = make(chan struct{})
+
+ // Clear API server state
+ apiServer.SetRegistered(false)
+ apiServer.SetTunnelIP("")
+ apiServer.SetOrgID(orgID)
+
+ // Send new registration message with updated orgId
+ publicKey := privateKey.PublicKey()
+ logger.Info("Sending registration message with new orgId: %s", orgID)
+
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ "publicKey": publicKey.String(),
+ "relay": !doHolepunch,
+ "olmVersion": config.Version,
+ "orgId": orgID,
+ }, 1*time.Second)
+ }
+ }()
+ }
+
select {
case <-ctx.Done():
logger.Info("Context cancelled")
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch)
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
select {
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
- if peerMonitor != nil {
- peerMonitor.Stop()
- }
-
- if uapiListener != nil {
- uapiListener.Close()
- }
- if dev != nil {
- dev.Close()
- }
+ // Use teardownTunnel to clean up all tunnel resources
+ teardownTunnel(tunnel)
if apiServer != nil {
apiServer.Stop()

7
go.mod
View File

@@ -3,20 +3,21 @@ module github.com/fosrl/olm
go 1.25
require (
github.com/Microsoft/go-winio v0.6.2
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
github.com/gorilla/websocket v1.5.3
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/net v0.45.0
golang.org/x/sys v0.37.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
software.sslmate.com/src/go-pkcs12 v0.6.0
)
require (
github.com/gorilla/websocket v1.5.3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/net v0.45.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
)

2
go.sum
View File

@@ -1,3 +1,5 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

351
holepunch/holepunch.go Normal file
View File

@@ -0,0 +1,351 @@
package holepunch
import (
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/bind"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DomainResolver is a function type for resolving domains to IP addresses
type DomainResolver func(string) (string, error)
// ExitNode represents a WireGuard exit node for hole punching
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
olmID string
token string
domainResolver DomainResolver
}
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
return &Manager{
sharedBind: sharedBind,
olmID: olmID,
domainResolver: domainResolver,
}
}
// SetToken updates the authentication token used for hole punching
func (m *Manager) SetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.token = token
}
// IsRunning returns whether hole punching is currently active
func (m *Manager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.running
}
// Stop stops any ongoing hole punch operations
func (m *Manager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
if m.stopChan != nil {
close(m.stopChan)
m.stopChan = nil
}
m.running = false
logger.Info("Hole punch manager stopped")
}
// StartMultipleExitNodes starts hole punching to multiple exit nodes
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
if len(exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes provided for hole punching")
return fmt.Errorf("no exit nodes provided")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
go m.runMultipleExitNodes(exitNodes)
return nil
}
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for all exit nodes")
}()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := m.domainResolver(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := m.domainResolver(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()
token := m.token
olmID := m.olmID
m.mu.Unlock()
if serverPubKey == "" || token == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}

View File

@@ -1,217 +0,0 @@
package httpserver
import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/fosrl/newt/logger"
)
// ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
Connected bool `json:"connected"`
RTT time.Duration `json:"rtt"`
LastSeen time.Time `json:"lastSeen"`
Endpoint string `json:"endpoint,omitempty"`
IsRelay bool `json:"isRelay"`
}
// StatusResponse is returned by the status endpoint
type StatusResponse struct {
Status string `json:"status"`
Connected bool `json:"connected"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
// HTTPServer represents the HTTP server and its state
type HTTPServer struct {
addr string
server *http.Server
connectionChan chan ConnectionRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
isConnected bool
tunnelIP string
version string
}
// NewHTTPServer creates a new HTTP server
func NewHTTPServer(addr string) *HTTPServer {
s := &HTTPServer{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
return s
}
// Start starts the HTTP server
func (s *HTTPServer) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
s.server = &http.Server{
Addr: s.addr,
Handler: mux,
}
logger.Info("Starting HTTP server on %s", s.addr)
go func() {
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("HTTP server error: %v", err)
}
}()
return nil
}
// Stop stops the HTTP server
func (s *HTTPServer) Stop() error {
logger.Info("Stopping HTTP server")
return s.server.Close()
}
// GetConnectionChannel returns the channel for receiving connection requests
func (s *HTTPServer) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *HTTPServer) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Connected = connected
status.RTT = rtt
status.LastSeen = time.Now()
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// SetConnectionStatus sets the overall connection status
func (s *HTTPServer) SetConnectionStatus(isConnected bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.isConnected = isConnected
if isConnected {
s.connectedAt = time.Now()
} else {
// Clear peer statuses when disconnected
s.peerStatuses = make(map[int]*PeerStatus)
}
}
// SetTunnelIP sets the tunnel IP address
func (s *HTTPServer) SetTunnelIP(tunnelIP string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.tunnelIP = tunnelIP
}
// SetVersion sets the olm version
func (s *HTTPServer) SetVersion(version string) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
s.version = version
}
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *HTTPServer) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
defer s.statusMu.Unlock()
status, exists := s.peerStatuses[siteID]
if !exists {
status = &PeerStatus{
SiteID: siteID,
}
s.peerStatuses[siteID] = status
}
status.Endpoint = endpoint
status.IsRelay = isRelay
}
// handleConnect handles the /connect endpoint
func (s *HTTPServer) handleConnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req ConnectionRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.ID == "" || req.Secret == "" || req.Endpoint == "" {
http.Error(w, "Missing required fields: id, secret, and endpoint must be provided", http.StatusBadRequest)
return
}
// Send the request to the main goroutine
s.connectionChan <- req
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "connection request accepted",
})
}
// handleStatus handles the /status endpoint
func (s *HTTPServer) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
s.statusMu.RLock()
defer s.statusMu.RUnlock()
resp := StatusResponse{
Connected: s.isConnected,
TunnelIP: s.tunnelIP,
Version: s.version,
PeerStatuses: s.peerStatuses,
}
if s.isConnected {
resp.Status = "connected"
} else {
resp.Status = "disconnected"
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}

789
main.go
View File

@@ -2,56 +2,16 @@ package main
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/httpserver"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/fosrl/olm/olm"
)
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func main() {
// Check if we're running as a Windows service
if isWindowsService() {
@@ -193,18 +153,26 @@ func main() {
}
}
// Run in console mode
runOlmMain(context.Background())
}
// Create a context that will be cancelled on interrupt signals
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
func runOlmMain(ctx context.Context) {
// Run in console mode
runOlmMainWithArgs(ctx, os.Args[1:])
}
func runOlmMainWithArgs(ctx context.Context, args []string) {
// Setup Windows event logging if on Windows
if runtime.GOOS != "windows" {
setupWindowsEventLog()
} else {
// Initialize logger for non-Windows platforms
logger.Init()
}
// Load configuration from file, env vars, and CLI args
// Priority: CLI args > Env vars > Config file > Defaults
config, showVersion, showConfig, err := LoadConfig(args)
config, showVersion, showConfig, err := LoadConfig(os.Args[1:])
if err != nil {
fmt.Printf("Failed to load configuration: %v\n", err)
return
@@ -216,36 +184,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
os.Exit(0)
}
// Extract commonly used values from config for convenience
var (
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
mtu = config.MTU
logLevel = config.LogLevel
interfaceName = config.InterfaceName
enableHTTP = config.EnableHTTP
httpAddr = config.HTTPAddr
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
connected bool
)
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
// Setup Windows event logging if on Windows
if runtime.GOOS == "windows" {
setupWindowsEventLog()
} else {
// Initialize logger for non-Windows platforms
logger.Init()
}
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
olmVersion := "version_replaceme"
if showVersion {
fmt.Println("Olm version " + olmVersion)
@@ -253,680 +191,35 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
}
logger.Info("Olm version " + olmVersion)
if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil {
logger.Debug("Failed to check for updates: %v", err)
config.Version = olmVersion
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Debug("Saved full olm config with all options")
}
// Log startup information
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
logger.Debug("HTTP enabled: %v, HTTP addr: %s", enableHTTP, httpAddr)
if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
// Create a new olm.Config struct and copy values from the main config
olmConfig := olm.Config{
Endpoint: config.Endpoint,
ID: config.ID,
Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU,
DNS: config.DNS,
InterfaceName: config.InterfaceName,
LogLevel: config.LogLevel,
EnableAPI: config.EnableAPI,
HTTPAddr: config.HTTPAddr,
SocketPath: config.SocketPath,
Holepunch: config.Holepunch,
TlsClientCert: config.TlsClientCert,
PingIntervalDuration: config.PingIntervalDuration,
PingTimeoutDuration: config.PingTimeoutDuration,
Version: config.Version,
OrgID: config.OrgID,
// DoNotCreateNewClient: config.DoNotCreateNewClient,
}
var httpServer *httpserver.HTTPServer
if enableHTTP {
httpServer = httpserver.NewHTTPServer(httpAddr)
httpServer.SetVersion(olmVersion)
if err := httpServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Use a goroutine to handle connection requests
go func() {
for req := range httpServer.GetConnectionChannel() {
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
}
}()
}
// // Check if required parameters are missing and provide helpful guidance
// missingParams := []string{}
// if id == "" {
// missingParams = append(missingParams, "id (use -id flag or OLM_ID env var)")
// }
// if secret == "" {
// missingParams = append(missingParams, "secret (use -secret flag or OLM_SECRET env var)")
// }
// if endpoint == "" {
// missingParams = append(missingParams, "endpoint (use -endpoint flag or PANGOLIN_ENDPOINT env var)")
// }
// if len(missingParams) > 0 {
// logger.Error("Missing required parameters: %v", missingParams)
// logger.Error("Either provide them as command line flags or set as environment variables")
// fmt.Printf("ERROR: Missing required parameters: %v\n", missingParams)
// fmt.Printf("Please provide them as command line flags or set as environment variables\n")
// if !enableHTTP {
// logger.Error("HTTP server is disabled, cannot receive parameters via API")
// fmt.Printf("HTTP server is disabled, cannot receive parameters via API\n")
// return
// }
// }
// Create a new olm
olm, err := websocket.NewClient(
"olm",
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
pingTimeout,
)
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
}
// wait until we have a client id and secret and endpoint
waitCount := 0
for id == "" || secret == "" || endpoint == "" {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
return
default:
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
waitCount++
if waitCount%10 == 1 { // Log every 10 seconds instead of every second
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
}
time.Sleep(1 * time.Second)
}
}
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// Create TUN device and network stack
var dev *device.Device
var wgData WgData
var holePunchData HolePunchData
var uapiListener net.Listener
var tdev tun.Device
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
os.Exit(1)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch goroutines by closing the current channel
select {
case <-stopHolepunch:
// Channel already closed
default:
close(stopHolepunch)
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start hole punching for each exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
close(stopHolepunch)
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, mtu)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtu)
}
return tun.CreateTUN(interfaceName, mtu)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
if err = dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if httpServer != nil {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !doHolepunch
break
}
}
httpServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
fixKey(privateKey.String()),
olm,
dev,
doHolepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if httpServer != nil {
httpServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.Start()
connected = true
logger.Info("WireGuard device created.")
})
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData UpdatePeerData
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: updateData.SiteId,
Endpoint: updateData.Endpoint,
PublicKey: updateData.PublicKey,
ServerIP: updateData.ServerIP,
ServerPort: updateData.ServerPort,
RemoteSubnets: updateData.RemoteSubnets,
}
// Update the peer in WireGuard
if dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break
}
}
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes
}
// Add new remote subnet routes
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err)
return
}
}
// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
}
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for adding a new peer
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for removing a peer
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData RemovePeerData
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
// Find the peer to remove
var peerToRemove *SiteConfig
var newSites []SiteConfig
for _, site := range wgData.Sites {
if site.SiteId == removeData.SiteId {
peerToRemove = &site
} else {
newSites = append(newSites, site)
}
}
if peerToRemove == nil {
logger.Error("Peer with site ID %d not found", removeData.SiteId)
return
}
// Remove the peer from WireGuard
if dev != nil {
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
}
// Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP)
if err != nil {
logger.Error("Failed to remove route for peer: %v", err)
return
}
// Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err)
return
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
// Update WgData to remove the peer
wgData.Sites = newSites
} else {
logger.Error("WireGuard device not initialized")
}
})
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := resolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
if httpServer != nil {
httpServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
logger.Info("Received no-sites message - no sites available for connection")
// if stopRegister != nil {
// stopRegister()
// stopRegister = nil
// }
// select {
// case <-stopHolepunch:
// // Channel already closed, do nothing
// default:
// close(stopHolepunch)
// }
logger.Info("No sites available - stopped registration and holepunch processes")
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
if httpServer != nil {
httpServer.SetConnectionStatus(true)
}
// CRITICAL: Save our full config AFTER websocket saves its limited config
// This ensures all 13 fields are preserved, not just the 4 that websocket saves
if err := SaveConfig(config); err != nil {
logger.Error("Failed to save full olm config: %v", err)
} else {
logger.Debug("Saved full olm config with all options")
}
if connected {
logger.Debug("Already connected, skipping registration")
return nil
}
publicKey := privateKey.PublicKey()
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": olmVersion,
}, 1*time.Second)
go keepSendingPing(olm)
logger.Info("Sent registration message")
return nil
})
olm.OnTokenUpdate(func(token string) {
olmToken = token
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer olm.Close()
// Wait for interrupt signal or context cancellation
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sigCh:
logger.Info("Received interrupt signal")
case <-ctx.Done():
logger.Info("Context cancelled")
}
select {
case <-stopHolepunch:
// Channel already closed, do nothing
default:
close(stopHolepunch)
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
select {
case <-stopPing:
// Channel already closed
default:
close(stopPing)
}
if uapiListener != nil {
uapiListener.Close()
}
if dev != nil {
dev.Close()
}
logger.Info("runOlmMain() exiting")
fmt.Printf("runOlmMain() exiting\n")
olm.Run(ctx, olmConfig)
}

View File

@@ -1,9 +1,8 @@
package main
package olm
import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os/exec"
@@ -17,10 +16,7 @@ import (
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -82,11 +78,6 @@ const (
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
type fixedPortBind struct {
port uint16
conn.Bind
}
// PeerAction represents a request to add, update, or remove a peer
type PeerAction struct {
Action string `json:"action"` // "add", "update", or "remove"
@@ -124,16 +115,31 @@ type RelayPeerData struct {
PublicKey string `json:"publicKey"`
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
return ""
}
// Check if it's already a valid host:port that SplitHostPort can parse (e.g., [::1]:8080 or 1.2.3.4:8080)
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint // Already valid, no change needed
}
// If it failed, it might be our malformed "ipv6:port" string. Let's check and fix it.
lastColon := strings.LastIndex(endpoint, ":")
if lastColon > 0 { // Ensure there is a colon and it's not the first character
hostPart := endpoint[:lastColon]
// Check if the host part is a literal IPv6 address
if ip := net.ParseIP(hostPart); ip != nil && ip.To4() == nil {
// It is! Reformat it with brackets.
portPart := endpoint[lastColon+1:]
return fmt.Sprintf("[%s]:%s", hostPart, portPart)
}
}
// If it's not the specific malformed case, return it as is.
return endpoint
}
func fixKey(key string) string {
@@ -182,7 +188,7 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
}
}
func resolveDomain(domain string) (string, error) {
func ResolveDomain(domain string) (string, error) {
// First handle any protocol prefix
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
@@ -229,273 +235,6 @@ func resolveDomain(domain string) (string, error) {
return ipAddr, nil
}
func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
if serverPubKey == "" || olmToken == "" {
return nil
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: olmToken,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
}
_, err = conn.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) {
if len(exitNodes) == 0 {
logger.Warn("No exit nodes provided for hole punching")
return
}
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes))
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
// Create the UDP connection once and reuse it for all exit nodes
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to bind UDP socket: %v", err)
return
}
defer conn.Close()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := resolveDomain(exitNode.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch for all exit nodes")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) {
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %s", endpoint)
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
host, err := resolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
// Create the UDP connection once and reuse it
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address: %v", err)
return
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to bind UDP socket: %v", err)
return
}
defer conn.Close()
// Execute once immediately before starting the loop
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
}
}
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
@@ -535,6 +274,7 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
@@ -571,7 +311,7 @@ func keepSendingPing(olm *websocket.Client) {
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
siteHost, err := resolveDomain(siteConfig.Endpoint)
siteHost, err := ResolveDomain(siteConfig.Endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
}
@@ -628,7 +368,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}

892
olm/olm.go Normal file
View File

@@ -0,0 +1,892 @@
package olm
import (
"context"
"encoding/json"
"net"
"os"
"runtime"
"strconv"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/holepunch"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Config struct {
// Connection settings
Endpoint string
ID string
Secret string
UserToken string
// Network settings
MTU int
DNS string
InterfaceName string
// Logging
LogLevel string
// HTTP server
EnableAPI bool
HTTPAddr string
SocketPath string
// Advanced
Holepunch bool
TlsClientCert string
// Parsed values (not in JSON)
PingIntervalDuration time.Duration
PingTimeoutDuration time.Duration
// Source tracking (not in JSON)
sources map[string]string
Version string
OrgID string
// DoNotCreateNewClient bool
}
var (
privateKey wgtypes.Key
connected bool
dev *device.Device
wgData WgData
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
tunnelRunning bool
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
)
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err)
}
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
} else if config.SocketPath != "" {
apiServer = api.NewAPISocket(config.SocketPath)
}
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Listen for shutdown requests from the API
go func() {
<-apiServer.GetShutdownChannel()
logger.Info("Shutdown requested via API")
// Cancel the context to trigger graceful shutdown
cancel()
}()
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
userToken = config.UserToken
)
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
if olmClient != nil {
logger.Info("Stopping existing tunnel before starting new connection")
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
userToken := req.UserToken
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials")
tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
userToken = ""
default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Stop()
apiServer.Stop()
logger.Info("Olm service shutting down")
}
func TunnelProcess(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
}()
// Recreate channels for this tunnel session
stopPing = make(chan struct{})
var (
interfaceName = config.InterfaceName
loggerLevel = parseLogLevel(config.LogLevel)
)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
id, // Use provided ID
secret, // Use provided secret
userToken, // Use provided user token OPTIONAL
endpoint, // Use provided endpoint
config.PingIntervalDuration,
config.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create olm: %v", err)
return
}
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Error("Failed to generate private key: %v", err)
return
}
// Create shared UDP socket for both holepunch and WireGuard
if sharedBind == nil {
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
logger.Error("Error finding available port: %v", err)
return
}
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to create shared UDP socket: %v", err)
return
}
sharedBind, err = bind.New(udpConn)
if err != nil {
logger.Error("Failed to create shared bind: %v", err)
udpConn.Close()
return
}
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
sharedBind.AddRef()
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
}
// Create the holepunch manager
if holePunchManager == nil {
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
for i, node := range holePunchData.ExitNodes {
exitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
PublicKey: node.PublicKey,
}
}
// Start hole punching using the manager
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
logger.Debug("Received message: %v", msg.Data)
type LegacyHolePunchData struct {
ServerPubKey string `json:"serverPubKey"`
Endpoint string `json:"endpoint"`
}
var legacyHolePunchData LegacyHolePunchData
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Stop any existing hole punch operations
if holePunchManager != nil {
holePunchManager.Stop()
}
// Start hole punching for the exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
if connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
if dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return
}
tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, config.MTU)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, config.MTU)
}
return tun.CreateTUN(interfaceName, config.MTU)
}()
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
return
}
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}
return uapiOpen(interfaceName)
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
if err = dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
apiServer.SetTunnelIP(wgData.TunnelIP)
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
for _, site := range wgData.Sites {
if site.SiteId == siteID {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !config.Holepunch
break
}
}
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
fixKey(privateKey.String()),
olm,
dev,
config.Holepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey)
}
peerMonitor.Start()
apiServer.SetRegistered(true)
connected = true
logger.Info("WireGuard device created.")
})
olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var updateData UpdatePeerData
if err := json.Unmarshal(jsonData, &updateData); err != nil {
logger.Error("Error unmarshaling update data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: updateData.SiteId,
Endpoint: updateData.Endpoint,
PublicKey: updateData.PublicKey,
ServerIP: updateData.ServerIP,
ServerPort: updateData.ServerPort,
RemoteSubnets: updateData.RemoteSubnets,
}
// Update the peer in WireGuard
if dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break
}
}
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets {
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
logger.Error("Failed to remove old remote subnet routes: %v", err)
// Continue anyway to add new routes
}
// Add new remote subnet routes
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add new remote subnet routes: %v", err)
return
}
}
// Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
}
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for adding a new peer
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var addData AddPeerData
if err := json.Unmarshal(jsonData, &addData); err != nil {
logger.Error("Error unmarshaling add data: %v", err)
return
}
// Convert to SiteConfig
siteConfig := SiteConfig{
SiteId: addData.SiteId,
Endpoint: addData.Endpoint,
PublicKey: addData.PublicKey,
ServerIP: addData.ServerIP,
ServerPort: addData.ServerPort,
RemoteSubnets: addData.RemoteSubnets,
}
// Add the peer to WireGuard
if dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId)
// Update WgData with the new peer
wgData.Sites = append(wgData.Sites, siteConfig)
} else {
logger.Error("WireGuard device not initialized")
}
})
// Handler for removing a peer
olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var removeData RemovePeerData
if err := json.Unmarshal(jsonData, &removeData); err != nil {
logger.Error("Error unmarshaling remove data: %v", err)
return
}
// Find the peer to remove
var peerToRemove *SiteConfig
var newSites []SiteConfig
for _, site := range wgData.Sites {
if site.SiteId == removeData.SiteId {
peerToRemove = &site
} else {
newSites = append(newSites, site)
}
}
if peerToRemove == nil {
logger.Error("Peer with site ID %d not found", removeData.SiteId)
return
}
// Remove the peer from WireGuard
if dev != nil {
if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
}
// Remove route for the peer
err = removeRouteForServerIP(peerToRemove.ServerIP)
if err != nil {
logger.Error("Failed to remove route for peer: %v", err)
return
}
// Remove routes for remote subnets
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
logger.Error("Failed to remove routes for remote subnets: %v", err)
return
}
// Remove successful
logger.Info("Successfully removed peer for site %d", removeData.SiteId)
// Update WgData to remove the peer
wgData.Sites = newSites
} else {
logger.Error("WireGuard device not initialized")
}
})
olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) {
logger.Debug("Received relay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData RelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
logger.Info("Received no-sites message - no sites available for connection")
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
logger.Info("No sites available - stopped registration and holepunch processes")
})
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
olm.Close()
})
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
apiServer.SetConnectionStatus(true)
if connected {
logger.Debug("Already connected, skipping registration")
return nil
}
publicKey := privateKey.PublicKey()
if stopRegister == nil {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
// "doNotCreateNewClient": config.DoNotCreateNewClient,
}, 1*time.Second)
}
go keepSendingPing(olm)
return nil
})
olm.OnTokenUpdate(func(token string) {
if holePunchManager != nil {
holePunchManager.SetToken(token)
}
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err)
return
}
defer olm.Close()
// Listen for org switch requests from the API
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Stop()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")
}
func Stop() {
// Stop hole punch manager
if holePunchManager != nil {
holePunchManager.Stop()
}
if stopPing != nil {
select {
case <-stopPing:
// Channel already closed
default:
close(stopPing)
}
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
if dev != nil {
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
dev = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
// Release the hole punch reference to the shared bind
if sharedBind != nil {
// Release hole punch reference (WireGuard already released its reference via dev.Close())
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
sharedBind.Release()
sharedBind = nil
logger.Info("Released shared UDP bind")
}
logger.Info("Olm service stopped")
}
// StopTunnel stops just the tunnel process and websocket connection
// without shutting down the entire application
func StopTunnel() {
logger.Info("Stopping tunnel process")
// Cancel the tunnel context if it exists
if tunnelCancel != nil {
tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
}
// Close the websocket connection
if olmClient != nil {
olmClient.Close()
olmClient = nil
}
Stop()
// Reset the connected state
connected = false
tunnelRunning = false
// Update API server status
apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
logger.Info("Tunnel process stopped")
}

View File

@@ -1,6 +1,6 @@
//go:build !windows
package main
package olm
import (
"net"

View File

@@ -1,6 +1,6 @@
//go:build windows
package main
package olm
import (
"errors"

View File

@@ -39,6 +39,7 @@ type Config struct {
Secret string
Endpoint string
TlsClientCert string // legacy PKCS12 file path
UserToken string // optional user token for websocket authentication
}
type Client struct {
@@ -103,11 +104,12 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
}
// NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{
ID: ID,
Secret: secret,
Endpoint: endpoint,
ID: ID,
Secret: secret,
Endpoint: endpoint,
UserToken: userToken,
}
client := &Client{
@@ -119,7 +121,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
isConnected: false,
pingInterval: pingInterval,
pingTimeout: pingTimeout,
clientType: clientType,
clientType: "olm",
}
// Apply options before loading config
@@ -263,17 +265,9 @@ func (c *Client) getToken() (string, error) {
var tokenData map[string]interface{}
// Get a new token
if c.clientType == "newt" {
tokenData = map[string]interface{}{
"newtId": c.config.ID,
"secret": c.config.Secret,
}
} else if c.clientType == "olm" {
tokenData = map[string]interface{}{
"olmId": c.config.ID,
"secret": c.config.Secret,
}
tokenData = map[string]interface{}{
"olmId": c.config.ID,
"secret": c.config.Secret,
}
jsonData, err := json.Marshal(tokenData)
@@ -384,6 +378,9 @@ func (c *Client) establishConnection() error {
q := u.Query()
q.Set("token", token)
q.Set("clientType", c.clientType)
if c.config.UserToken != "" {
q.Set("userToken", c.config.UserToken)
}
u.RawQuery = q.Encode()
// Connect to WebSocket