diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_linux.go index b9d5ff24f..78ea6fb7e 100644 --- a/client/internal/dns/network_manager_linux.go +++ b/client/internal/dns/network_manager_linux.go @@ -76,12 +76,12 @@ func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, } defer closeConn() var s string - err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s) + err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s) if err != nil { return nil, err } - log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName()) + log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name()) return &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_linux.go index c1802858d..44e68c725 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_linux.go @@ -17,7 +17,7 @@ type resolvconf struct { func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) { return &resolvconf{ - ifaceName: wgInterface.GetName(), + ifaceName: wgInterface.Name(), }, nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index a267ab94e..7a9d5301e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -136,7 +136,7 @@ func (s *DefaultServer) Start() { func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) { ips := []string{defaultIP, customIP} if runtime.GOOS != "darwin" && s.wgInterface != nil { - ips = append([]string{s.wgInterface.GetAddress().IP.String()}, ips...) + ips = append([]string{s.wgInterface.Address().IP.String()}, ips...) } ports := []int{defaultPort, customPort} for _, port := range ports { diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index fbb16fe64..235b6c504 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -51,7 +51,7 @@ type systemdDbusLinkDomainsInput struct { } func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { - iface, err := net.InterfaceByName(wgInterface.GetName()) + iface, err := net.InterfaceByName(wgInterface.Name()) if err != nil { return nil, err } diff --git a/client/internal/engine.go b/client/internal/engine.go index 0594082e3..d629e6b18 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -158,12 +158,10 @@ func (e *Engine) Stop() error { time.Sleep(500 * time.Millisecond) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) - if e.wgInterface.Interface != nil { - err = e.wgInterface.Close() - if err != nil { - log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) - return err - } + err = e.wgInterface.Close() + if err != nil { + log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) + return err } if e.udpMux != nil { @@ -501,7 +499,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { //nil sshServer means it has not yet been started var err error e.sshServer, err = e.sshServerFunc(e.config.SSHKey, - fmt.Sprintf("%s:%d", e.wgInterface.Address.IP.String(), nbssh.DefaultSSHPort)) + fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)) if err != nil { return err } @@ -534,8 +532,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { - if e.wgInterface.Address.String() != conf.Address { - oldAddr := e.wgInterface.Address.String() + if e.wgInterface.Address().String() != conf.Address { + oldAddr := e.wgInterface.Address().String() log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) err := e.wgInterface.UpdateAddr(conf.Address) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 611c0a05c..c4606aebe 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -857,7 +857,7 @@ loop: } // cleanup test for n, peerEngine := range engines { - t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name, n) + t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n) errStop := peerEngine.mgmClient.Close() if errStop != nil { log.Infoln("got error trying to close management clients from engine: ", errStop) @@ -905,7 +905,7 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { expectedOutput: []string{"1.1.1.1", "8.8.8.8/" + testingIP}, }, { - name: "Only Interface Name Should Return Nil", + name: "Only Interface name Should Return Nil", inputBlacklistInterface: defaultInterfaceBlacklist, inputMapList: []string{testingInterface}, expectedOutput: nil, diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index e35368060..75a10e1f8 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -162,7 +162,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if err != nil { return err } - err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String()) + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String()) if err != nil { return fmt.Errorf("couldn't remove route %s from system, err: %v", c.network, err) @@ -201,10 +201,10 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return err } } else { - err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String()) + err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String()) if err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", - c.network.String(), c.wgInterface.GetAddress().IP.String(), err) + c.network.String(), c.wgInterface.Address().IP.String(), err) } } diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go index 0bfd1cec5..a8445894b 100644 --- a/client/internal/routemanager/server.go +++ b/client/internal/routemanager/server.go @@ -40,7 +40,7 @@ func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error { default: m.serverRouter.mux.Lock() defer m.serverRouter.mux.Unlock() - err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { return err } @@ -57,7 +57,7 @@ func (m *DefaultManager) addToServerNetwork(route *route.Route) error { default: m.serverRouter.mux.Lock() defer m.serverRouter.mux.Unlock() - err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route)) if err != nil { return err } diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 8f6f4c6fe..234a4973d 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -39,18 +39,18 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String()) + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String()) require.NoError(t, err, "should not return err") prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) require.NoError(t, err, "should not return err") if testCase.shouldRouteToWireguard { - require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") } else { - require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface") + require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface") } - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String()) require.NoError(t, err, "should not return err") prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) diff --git a/iface/address.go b/iface/address.go new file mode 100644 index 000000000..5ff4fbc06 --- /dev/null +++ b/iface/address.go @@ -0,0 +1,29 @@ +package iface + +import ( + "fmt" + "net" +) + +// WGAddress Wireguard parsed address +type WGAddress struct { + IP net.IP + Network *net.IPNet +} + +// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address +func parseWGAddress(address string) (WGAddress, error) { + ip, network, err := net.ParseCIDR(address) + if err != nil { + return WGAddress{}, err + } + return WGAddress{ + IP: ip, + Network: network, + }, nil +} + +func (addr WGAddress) String() string { + maskSize, _ := addr.Network.Mask.Size() + return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) +} diff --git a/iface/configuration.go b/iface/configuration.go deleted file mode 100644 index 07adc1555..000000000 --- a/iface/configuration.go +++ /dev/null @@ -1,258 +0,0 @@ -package iface - -import ( - "fmt" - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "net" - "time" -) - -// GetName returns the interface name -func (w *WGIface) GetName() string { - return w.Name -} - -// GetAddress returns the interface address -func (w *WGIface) GetAddress() WGAddress { - return w.Address -} - -// configureDevice configures the wireguard device -func (w *WGIface) configureDevice(config wgtypes.Config) error { - wg, err := wgctrl.New() - if err != nil { - return err - } - defer wg.Close() - - // validate if device with name exists - _, err = wg.Device(w.Name) - if err != nil { - return err - } - log.Debugf("got Wireguard device %s", w.Name) - - return wg.ConfigureDevice(w.Name, config) -} - -// Configure configures a Wireguard interface -// The interface must exist before calling this method (e.g. call interface.Create() before) -func (w *WGIface) Configure(privateKey string, port int) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("configuring Wireguard interface %s", w.Name) - - log.Debugf("adding Wireguard private key") - key, err := wgtypes.ParseKey(privateKey) - if err != nil { - return err - } - fwmark := 0 - config := wgtypes.Config{ - PrivateKey: &key, - ReplacePeers: true, - FirewallMark: &fwmark, - ListenPort: &port, - } - - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf("received error \"%v\" while configuring interface %s with port %d", err, w.Name, port) - } - return nil -} - -// GetListenPort returns the listening port of the Wireguard endpoint -func (w *WGIface) GetListenPort() (*int, error) { - log.Debugf("getting Wireguard listen port of interface %s", w.Name) - - //discover Wireguard current configuration - wg, err := wgctrl.New() - if err != nil { - return nil, err - } - defer wg.Close() - - d, err := wg.Device(w.Name) - if err != nil { - return nil, err - } - log.Debugf("got Wireguard device listen port %s, %d", w.Name, d.ListenPort) - - return &d.ListenPort, nil -} - -// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist -// Endpoint is optional -func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("updating interface %s peer %s: endpoint %s ", w.Name, peerKey, endpoint) - - //parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{*ipNet}, - PersistentKeepaliveInterval: &keepAlive, - PresharedKey: preSharedKey, - Endpoint: endpoint, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf("received error \"%v\" while updating peer on interface %s with settings: allowed ips %s, endpoint %s", err, w.Name, allowedIps, endpoint.String()) - } - return nil -} - -// AddAllowedIP adds a prefix to the allowed IPs list of peer -func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) - - _, ipNet, err := net.ParseCIDR(allowedIP) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - UpdateOnly: true, - ReplaceAllowedIPs: false, - AllowedIPs: []net.IPNet{*ipNet}, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf("received error \"%v\" while adding allowed Ip to peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP) - } - return nil -} - -// RemoveAllowedIP removes a prefix from the allowed IPs list of peer -func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) - - _, ipNet, err := net.ParseCIDR(allowedIP) - if err != nil { - return err - } - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - - existingPeer, err := getPeer(w.Name, peerKey) - if err != nil { - return err - } - - newAllowedIPs := existingPeer.AllowedIPs - - for i, existingAllowedIP := range existingPeer.AllowedIPs { - if existingAllowedIP.String() == ipNet.String() { - newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) - break - } - } - - if err != nil { - return err - } - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - UpdateOnly: true, - ReplaceAllowedIPs: true, - AllowedIPs: newAllowedIPs, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf("received error \"%v\" while removing allowed IP from peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP) - } - return nil -} - -func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { - wg, err := wgctrl.New() - if err != nil { - return wgtypes.Peer{}, err - } - defer func() { - err = wg.Close() - if err != nil { - log.Errorf("got error while closing wgctl: %v", err) - } - }() - - wgDevice, err := wg.Device(ifaceName) - if err != nil { - return wgtypes.Peer{}, err - } - for _, peer := range wgDevice.Peers { - if peer.PublicKey.String() == peerPubKey { - return peer, nil - } - } - return wgtypes.Peer{}, fmt.Errorf("peer not found") -} - -// RemovePeer removes a Wireguard Peer from the interface iface -func (w *WGIface) RemovePeer(peerKey string) error { - w.mu.Lock() - defer w.mu.Unlock() - - log.Debugf("Removing peer %s from interface %s ", peerKey, w.Name) - - peerKeyParsed, err := wgtypes.ParseKey(peerKey) - if err != nil { - return err - } - - peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - Remove: true, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peer}, - } - err = w.configureDevice(config) - if err != nil { - return fmt.Errorf("received error \"%v\" while removing peer %s from interface %s", err, peerKey, w.Name) - } - return nil -} diff --git a/iface/iface.go b/iface/iface.go index d75c4db86..4e88f57e7 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -3,9 +3,12 @@ package iface import ( "fmt" "net" - "os" - "runtime" "sync" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) const ( @@ -13,83 +16,276 @@ const ( DefaultWgPort = 51820 ) -// WGIface represents a interface instance -type WGIface struct { - Name string - Port int - MTU int - Address WGAddress - Interface NetInterface - mu sync.Mutex -} - -// WGAddress Wireguard parsed address -type WGAddress struct { - IP net.IP - Network *net.IPNet -} - -func (addr *WGAddress) String() string { - maskSize, _ := addr.Network.Mask.Size() - return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) -} - // NetInterface represents a generic network tunnel interface type NetInterface interface { Close() error } +// WGIface represents a interface instance +type WGIface struct { + name string + address WGAddress + mtu int + netInterface NetInterface + mu sync.Mutex +} + // NewWGIFace Creates a new Wireguard interface instance func NewWGIFace(iface string, address string, mtu int) (*WGIface, error) { wgIface := &WGIface{ - Name: iface, - MTU: mtu, + name: iface, + mtu: mtu, mu: sync.Mutex{}, } - wgAddress, err := parseAddress(address) + wgAddress, err := parseWGAddress(address) if err != nil { return wgIface, err } - wgIface.Address = wgAddress + wgIface.address = wgAddress return wgIface, nil } -// parseAddress parse a string ("1.2.3.4/24") address to WG Address -func parseAddress(address string) (WGAddress, error) { - ip, network, err := net.ParseCIDR(address) - if err != nil { - return WGAddress{}, err - } - return WGAddress{ - IP: ip, - Network: network, - }, nil +// Name returns the interface name +func (w *WGIface) Name() string { + return w.name } -// Close closes the tunnel interface -func (w *WGIface) Close() error { +// Address returns the interface address +func (w *WGIface) Address() WGAddress { + return w.address +} + +// Configure configures a Wireguard interface +// The interface must exist before calling this method (e.g. call interface.Create() before) +func (w *WGIface) Configure(privateKey string, port int) error { w.mu.Lock() defer w.mu.Unlock() - if w.Interface == nil { - return nil + + log.Debugf("configuring Wireguard interface %s", w.name) + + log.Debugf("adding Wireguard private key") + key, err := wgtypes.ParseKey(privateKey) + if err != nil { + return err } - err := w.Interface.Close() + fwmark := 0 + config := wgtypes.Config{ + PrivateKey: &key, + ReplacePeers: true, + FirewallMark: &fwmark, + ListenPort: &port, + } + + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, w.name, port) + } + return nil +} + +// UpdateAddr updates address of the interface +func (w *WGIface) UpdateAddr(newAddr string) error { + w.mu.Lock() + defer w.mu.Unlock() + + addr, err := parseWGAddress(newAddr) if err != nil { return err } - if runtime.GOOS != "windows" { - sockPath := "/var/run/wireguard/" + w.Name + ".sock" - if _, statErr := os.Stat(sockPath); statErr == nil { - statErr = os.Remove(sockPath) - if statErr != nil { - return statErr - } + w.address = addr + return w.assignAddr() +} + +// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist +// Endpoint is optional +func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("updating interface %s peer %s: endpoint %s ", w.name, peerKey, endpoint) + + //parse allowed ips + _, ipNet, err := net.ParseCIDR(allowedIps) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{*ipNet}, + PersistentKeepaliveInterval: &keepAlive, + PresharedKey: preSharedKey, + Endpoint: endpoint, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, w.name, allowedIps, endpoint.String()) + } + return nil +} + +// AddAllowedIP adds a prefix to the allowed IPs list of peer +func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP) + + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: false, + AllowedIPs: []net.IPNet{*ipNet}, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP) + } + return nil +} + +// RemoveAllowedIP removes a prefix from the allowed IPs list of peer +func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.name, peerKey, allowedIP) + + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + existingPeer, err := getPeer(w.name, peerKey) + if err != nil { + return err + } + + newAllowedIPs := existingPeer.AllowedIPs + + for i, existingAllowedIP := range existingPeer.AllowedIPs { + if existingAllowedIP.String() == ipNet.String() { + newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) + break } } + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: true, + AllowedIPs: newAllowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, w.name, allowedIP) + } return nil } + +// RemovePeer removes a Wireguard Peer from the interface iface +func (w *WGIface) RemovePeer(peerKey string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("Removing peer %s from interface %s ", peerKey, w.name) + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, w.name) + } + return nil +} + +func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { + wg, err := wgctrl.New() + if err != nil { + return wgtypes.Peer{}, err + } + defer func() { + err = wg.Close() + if err != nil { + log.Errorf("got error while closing wgctl: %v", err) + } + }() + + wgDevice, err := wg.Device(ifaceName) + if err != nil { + return wgtypes.Peer{}, err + } + for _, peer := range wgDevice.Peers { + if peer.PublicKey.String() == peerPubKey { + return peer, nil + } + } + return wgtypes.Peer{}, fmt.Errorf("peer not found") +} + +// configureDevice configures the wireguard device +func (w *WGIface) configureDevice(config wgtypes.Config) error { + wg, err := wgctrl.New() + if err != nil { + return err + } + defer wg.Close() + + // validate if device with name exists + _, err = wg.Device(w.name) + if err != nil { + return err + } + log.Debugf("got Wireguard device %s", w.name) + + return wg.ConfigureDevice(w.name, config) +} diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go index 22c089d27..fd1b6334a 100644 --- a/iface/iface_darwin.go +++ b/iface/iface_darwin.go @@ -1,8 +1,9 @@ package iface import ( - log "github.com/sirupsen/logrus" "os/exec" + + log "github.com/sirupsen/logrus" ) // Create Creates a new Wireguard interface, sets a given IP and brings it up. @@ -15,26 +16,17 @@ func (w *WGIface) Create() error { // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (w *WGIface) assignAddr() error { - //mask,_ := w.Address.Network.Mask.Size() - // - //address := fmt.Sprintf("%s/%d",w.Address.IP.String() , mask) - - cmd := exec.Command("ifconfig", w.Name, "inet", w.Address.IP.String(), w.Address.IP.String()) + cmd := exec.Command("ifconfig", w.name, "inet", w.address.IP.String(), w.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { - log.Infof("adding addreess command \"%v\" failed with output %s and error: ", cmd.String(), out) + log.Infof(`adding addreess command "%v" failed with output %s and error: `, cmd.String(), out) return err } - routeCmd := exec.Command("route", "add", "-net", w.Address.Network.String(), "-interface", w.Name) + routeCmd := exec.Command("route", "add", "-net", w.address.Network.String(), "-interface", w.name) if out, err := routeCmd.CombinedOutput(); err != nil { - log.Printf("adding route command \"%v\" failed with output %s and error: ", routeCmd.String(), out) + log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out) return err } return nil } - -// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) -func WireguardModuleIsLoaded() bool { - return false -} diff --git a/iface/iface_linux.go b/iface/iface_linux.go index cbbe0c6f8..042ed5bd6 100644 --- a/iface/iface_linux.go +++ b/iface/iface_linux.go @@ -2,15 +2,12 @@ package iface import ( "fmt" + "os" + log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - "os" ) -type NativeLink struct { - Link *netlink.Link -} - // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. func (w *WGIface) Create() error { @@ -33,10 +30,10 @@ func (w *WGIface) Create() error { // Works for Linux and offers much better network performance func (w *WGIface) createWithKernel() error { - link := newWGLink(w.Name) + link := newWGLink(w.name) // check if interface exists - l, err := netlink.LinkByName(w.Name) + l, err := netlink.LinkByName(w.name) if err != nil { switch err.(type) { case netlink.LinkNotFoundError: @@ -54,15 +51,15 @@ func (w *WGIface) createWithKernel() error { } } - log.Debugf("adding device: %s", w.Name) + log.Debugf("adding device: %s", w.name) err = netlink.LinkAdd(link) if os.IsExist(err) { - log.Infof("interface %s already exists. Will reuse.", w.Name) + log.Infof("interface %s already exists. Will reuse.", w.name) } else if err != nil { return err } - w.Interface = link + w.netInterface = link err = w.assignAddr() if err != nil { @@ -70,17 +67,17 @@ func (w *WGIface) createWithKernel() error { } // todo do a discovery - log.Debugf("setting MTU: %d interface: %s", w.MTU, w.Name) - err = netlink.LinkSetMTU(link, w.MTU) + log.Debugf("setting MTU: %d interface: %s", w.mtu, w.name) + err = netlink.LinkSetMTU(link, w.mtu) if err != nil { - log.Errorf("error setting MTU on interface: %s", w.Name) + log.Errorf("error setting MTU on interface: %s", w.name) return err } - log.Debugf("bringing up interface: %s", w.Name) + log.Debugf("bringing up interface: %s", w.name) err = netlink.LinkSetUp(link) if err != nil { - log.Errorf("error bringing up interface: %s", w.Name) + log.Errorf("error bringing up interface: %s", w.name) return err } @@ -89,7 +86,7 @@ func (w *WGIface) createWithKernel() error { // assignAddr Adds IP address to the tunnel interface func (w *WGIface) assignAddr() error { - link := newWGLink(w.Name) + link := newWGLink(w.name) //delete existing addresses list, err := netlink.AddrList(link, 0) @@ -105,11 +102,11 @@ func (w *WGIface) assignAddr() error { } } - log.Debugf("adding address %s to interface: %s", w.Address.String(), w.Name) - addr, _ := netlink.ParseAddr(w.Address.String()) + log.Debugf("adding address %s to interface: %s", w.address.String(), w.name) + addr, _ := netlink.ParseAddr(w.address.String()) err = netlink.AddrAdd(link, addr) if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", w.Name, w.Address.String()) + log.Infof("interface %s already has the address: %s", w.name, w.address.String()) } else if err != nil { return err } diff --git a/iface/iface_test.go b/iface/iface_test.go index 8521e2f17..c34e2959c 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -46,11 +46,11 @@ func TestWGIface_UpdateAddr(t *testing.T) { t.Error(err) } }() - port, err := iface.GetListenPort() + port, err := getListenPortByName(ifaceName) if err != nil { t.Fatal(err) } - err = iface.Configure(key, *port) + err = iface.Configure(key, port) if err != nil { t.Fatal(err) } @@ -164,11 +164,11 @@ func Test_ConfigureInterface(t *testing.T) { } }() - port, err := iface.GetListenPort() + port, err := getListenPortByName(ifaceName) if err != nil { t.Fatal(err) } - err = iface.Configure(key, *port) + err = iface.Configure(key, port) if err != nil { t.Fatal(err) } @@ -210,11 +210,11 @@ func Test_UpdatePeer(t *testing.T) { t.Error(err) } }() - port, err := iface.GetListenPort() + port, err := getListenPortByName(ifaceName) if err != nil { t.Fatal(err) } - err = iface.Configure(key, *port) + err = iface.Configure(key, port) if err != nil { t.Fatal(err) } @@ -269,11 +269,11 @@ func Test_RemovePeer(t *testing.T) { t.Error(err) } }() - port, err := iface.GetListenPort() + port, err := getListenPortByName(ifaceName) if err != nil { t.Fatal(err) } - err = iface.Configure(key, *port) + err = iface.Configure(key, port) if err != nil { t.Fatal(err) } @@ -298,12 +298,10 @@ func Test_ConnectPeers(t *testing.T) { peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400) peer1wgIP := "10.99.99.17/30" peer1Key, _ := wgtypes.GeneratePrivateKey() - //peer1Port := WgPort + 4 - peer2ifaceName := fmt.Sprintf("utun%d", 500) + peer2ifaceName := "utun500" peer2wgIP := "10.99.99.18/30" peer2Key, _ := wgtypes.GeneratePrivateKey() - //peer2Port := WgPort + 5 keepAlive := 1 * time.Second @@ -315,11 +313,11 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - peer1Port, err := iface1.GetListenPort() + peer1Port, err := getListenPortByName(peer1ifaceName) if err != nil { t.Fatal(err) } - peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", *peer1Port)) + peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1Port)) if err != nil { t.Fatal(err) } @@ -332,11 +330,11 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - peer2Port, err := iface2.GetListenPort() + peer2Port, err := getListenPortByName(peer2ifaceName) if err != nil { t.Fatal(err) } - peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", *peer2Port)) + peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2Port)) if err != nil { t.Fatal(err) } @@ -351,11 +349,11 @@ func Test_ConnectPeers(t *testing.T) { } }() - err = iface1.Configure(peer1Key.String(), *peer1Port) + err = iface1.Configure(peer1Key.String(), peer1Port) if err != nil { t.Fatal(err) } - err = iface2.Configure(peer2Key.String(), *peer2Port) + err = iface2.Configure(peer2Key.String(), peer2Port) if err != nil { t.Fatal(err) } @@ -388,3 +386,18 @@ func Test_ConnectPeers(t *testing.T) { } } + +func getListenPortByName(name string) (int, error) { + wg, err := wgctrl.New() + if err != nil { + return 0, err + } + defer wg.Close() + + d, err := wg.Device(name) + if err != nil { + return 0, err + } + + return d.ListenPort, nil +} diff --git a/iface/iface_unix.go b/iface/iface_unix.go index ebac5d8a1..be09afa9e 100644 --- a/iface/iface_unix.go +++ b/iface/iface_unix.go @@ -4,23 +4,53 @@ package iface import ( + "net" + "os" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" - "net" ) -// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation -func (w *WGIface) createWithUserspace() error { +// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + return "", nil +} - tunIface, err := tun.CreateTUN(w.Name, w.MTU) +// Close closes the tunnel interface +func (w *WGIface) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.netInterface == nil { + return nil + } + err := w.netInterface.Close() if err != nil { return err } - w.Interface = tunIface + sockPath := "/var/run/wireguard/" + w.name + ".sock" + if _, statErr := os.Stat(sockPath); statErr == nil { + statErr = os.Remove(sockPath) + if statErr != nil { + return statErr + } + } + + return nil +} + +// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation +func (w *WGIface) createWithUserspace() error { + + tunIface, err := tun.CreateTUN(w.name, w.mtu) + if err != nil { + return err + } + + w.netInterface = tunIface // We need to create a wireguard-go device and listen to configuration requests tunDevice := device.NewDevice(tunIface, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) @@ -28,7 +58,7 @@ func (w *WGIface) createWithUserspace() error { if err != nil { return err } - uapi, err := getUAPI(w.Name) + uapi, err := getUAPI(w.name) if err != nil { return err } @@ -61,22 +91,3 @@ func getUAPI(iface string) (net.Listener, error) { } return ipc.UAPIListen(iface, tunSock) } - -// UpdateAddr updates address of the interface -func (w *WGIface) UpdateAddr(newAddr string) error { - w.mu.Lock() - defer w.mu.Unlock() - - addr, err := parseAddress(newAddr) - if err != nil { - return err - } - - w.Address = addr - return w.assignAddr() -} - -// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only -func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return "", nil -} diff --git a/iface/iface_windows.go b/iface/iface_windows.go index 5c16916b9..2fefe3402 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -2,11 +2,11 @@ package iface import ( "fmt" + "net" + log "github.com/sirupsen/logrus" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/driver" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "net" ) // Create Creates a new Wireguard interface, sets a given IP and brings it up. @@ -15,55 +15,27 @@ func (w *WGIface) Create() error { defer w.mu.Unlock() WintunStaticRequestedGUID, _ := windows.GenerateGUID() - adapter, err := driver.CreateAdapter(w.Name, "WireGuard", &WintunStaticRequestedGUID) + adapter, err := driver.CreateAdapter(w.name, "WireGuard", &WintunStaticRequestedGUID) if err != nil { err = fmt.Errorf("error creating adapter: %w", err) return err } - w.Interface = adapter - luid := adapter.LUID() + w.netInterface = adapter err = adapter.SetAdapterState(driver.AdapterStateUp) if err != nil { return err } - state, _ := luid.GUID() + state, _ := adapter.LUID().GUID() log.Debugln("device guid: ", state.String()) - return w.assignAddr(luid) -} - -// assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (w *WGIface) assignAddr(luid winipcfg.LUID) error { - - log.Debugf("adding address %s to interface: %s", w.Address.IP, w.Name) - err := luid.SetIPAddresses([]net.IPNet{{w.Address.IP, w.Address.Network.Mask}}) - if err != nil { - return err - } - - return nil -} - -// UpdateAddr updates address of the interface -func (w *WGIface) UpdateAddr(newAddr string) error { - w.mu.Lock() - defer w.mu.Unlock() - - luid := w.Interface.(*driver.Adapter).LUID() - addr, err := parseAddress(newAddr) - if err != nil { - return err - } - - w.Address = addr - return w.assignAddr(luid) + return w.assignAddr() } // GetInterfaceGUIDString returns an interface GUID string func (w *WGIface) GetInterfaceGUIDString() (string, error) { - if w.Interface == nil { + if w.netInterface == nil { return "", fmt.Errorf("interface has not been initialized yet") } - windowsDevice := w.Interface.(*driver.Adapter) + windowsDevice := w.netInterface.(*driver.Adapter) luid := windowsDevice.LUID() guid, err := luid.GUID() if err != nil { @@ -72,7 +44,26 @@ func (w *WGIface) GetInterfaceGUIDString() (string, error) { return guid.String(), nil } -// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) -func WireguardModuleIsLoaded() bool { - return false +// Close closes the tunnel interface +func (w *WGIface) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.netInterface == nil { + return nil + } + + return w.netInterface.Close() +} + +// assignAddr Adds IP address to the tunnel interface and network route based on the range provided +func (w *WGIface) assignAddr() error { + luid := w.netInterface.(*driver.Adapter).LUID() + + log.Debugf("adding address %s to interface: %s", w.address.IP, w.name) + err := luid.SetIPAddresses([]net.IPNet{{w.address.IP, w.address.Network.Mask}}) + if err != nil { + return err + } + + return nil } diff --git a/iface/module.go b/iface/module.go new file mode 100644 index 000000000..08aa4c5e3 --- /dev/null +++ b/iface/module.go @@ -0,0 +1,9 @@ +//go:build !linux +// +build !linux + +package iface + +// WireguardModuleIsLoaded check if we can load wireguard mod (linux only) +func WireguardModuleIsLoaded() bool { + return false +}