[client] Refactor/iface pkg (#2646)

Refactor the flat code structure
This commit is contained in:
Zoltan Papp
2024-10-02 18:24:22 +02:00
committed by GitHub
parent 7e5d3bdfe2
commit fd67892cb4
105 changed files with 505 additions and 438 deletions

View File

@@ -0,0 +1,5 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@@ -0,0 +1,220 @@
//go:build (linux && !android) || freebsd
package configurer
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type KernelConfigurer struct {
deviceName string
}
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
return &KernelConfigurer{
deviceName: deviceName,
}
}
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, c.deviceName, port)
}
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, c.deviceName, allowedIps, endpoint.String())
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, c.deviceName)
}
return nil
}
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
}
return nil
}
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) //nolint:gocritic
break
}
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf("remove allowed IP %s on interface %s: %w", allowedIP, c.deviceName, err)
}
return nil
}
func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("get device %s: %w", ifaceName, err)
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, ErrPeerNotFound
}
func (c *KernelConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
}
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
if err != nil {
return err
}
return wg.ConfigureDevice(c.deviceName, config)
}
func (c *KernelConfigurer) Close() {
}
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
}
return WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
}, nil
}

View File

@@ -0,0 +1,6 @@
//go:build linux || windows || freebsd
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0"

View File

@@ -0,0 +1,6 @@
//go:build darwin
package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100"

View File

@@ -0,0 +1,26 @@
//go:build !windows
package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
log.Errorf("failed to listen on uapi socket: %v", err)
return nil, err
}
return listener, nil
}

View File

@@ -0,0 +1,11 @@
package configurer
import (
"net"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
return ipc.UAPIListen(deviceName)
}

View File

@@ -0,0 +1,369 @@
package configurer
import (
"encoding/hex"
"fmt"
"net"
"os"
"runtime"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct {
device *device.Device
deviceName string
uapiListener net.Listener
}
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
}
wgCfg.startUAPI()
return wgCfg
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipc, "\n")
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{},
}
foundPeer := false
removedAllowedIP := false
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, reset the flag.
if strings.HasPrefix(line, "public_key=") && foundPeer {
foundPeer = false
}
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
// If we're within the details of the found peer and find the specific allowed IP, skip this line
if foundPeer && line == "allowed_ip="+ip {
removedAllowedIP = true
continue
}
// Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peer.AllowedIPs = append(peer.AllowedIPs, *ipNet)
}
}
if !removedAllowedIP {
return ErrAllowedIPNotFound
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
return c.device.IpcSet(toWgUserspaceString(config))
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() {
var err error
t.uapiListener, err = openUAPI(t.deviceName)
if err != nil {
log.Errorf("failed to open uapi listener: %v", err)
return
}
go func(uapi net.Listener) {
for {
uapiConn, uapiErr := uapi.Accept()
if uapiErr != nil {
log.Tracef("%s", uapiErr)
return
}
go func() {
t.device.IpcHandle(uapiConn)
}()
}
}(t.uapiListener)
}
func (t *WGUSPConfigurer) Close() {
if t.uapiListener != nil {
err := t.uapiListener.Close()
if err != nil {
log.Errorf("failed to close uapi listener: %v", err)
}
}
if runtime.GOOS == "linux" {
sockPath := "/var/run/wireguard/" + t.deviceName + ".sock"
if _, statErr := os.Stat(sockPath); statErr == nil {
_ = os.Remove(sockPath)
}
}
}
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err)
}
stats, err := findPeerInfo(ipc, peerKey, []string{
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
}
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
}
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
}
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
}
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return nil, fmt.Errorf("parse key: %w", err)
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") && foundPeer {
break
}
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
for _, key := range searchConfigKeys {
if foundPeer && strings.HasPrefix(line, key+"=") {
v := strings.SplitN(line, "=", 2)
configFound[v[0]] = v[1]
}
}
}
// todo: use multierr
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
}
func toWgUserspaceString(wgCfg wgtypes.Config) string {
var sb strings.Builder
if wgCfg.PrivateKey != nil {
hexKey := hex.EncodeToString(wgCfg.PrivateKey[:])
sb.WriteString(fmt.Sprintf("private_key=%s\n", hexKey))
}
if wgCfg.ListenPort != nil {
sb.WriteString(fmt.Sprintf("listen_port=%d\n", *wgCfg.ListenPort))
}
if wgCfg.ReplacePeers {
sb.WriteString("replace_peers=true\n")
}
if wgCfg.FirewallMark != nil {
sb.WriteString(fmt.Sprintf("fwmark=%d\n", *wgCfg.FirewallMark))
}
for _, p := range wgCfg.Peers {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
}
return sb.String()
}
func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0
}

View File

@@ -0,0 +1,104 @@
package configurer
import (
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var ipcFixture = `
private_key=e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a
listen_port=12912
public_key=b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33
preshared_key=188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52
allowed_ip=192.168.4.4/32
endpoint=[abcd:23::33%2]:51820
public_key=58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376
tx_bytes=38333
rx_bytes=2224
allowed_ip=192.168.4.6/32
persistent_keepalive_interval=111
endpoint=182.122.22.19:3233
public_key=662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58
endpoint=5.152.198.39:51820
allowed_ip=192.168.4.10/32
allowed_ip=192.168.4.11/32
tx_bytes=1212111
rx_bytes=1929999999
protocol_version=1
errno=0
`
func Test_findPeerInfo(t *testing.T) {
tests := []struct {
name string
peerKey string
searchKeys []string
want map[string]string
wantErr bool
}{
{
name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
},
wantErr: false,
},
{
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "38333",
"rx_bytes": "2224",
},
wantErr: false,
},
{
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"},
want: map[string]string{
"tx_bytes": "1212111",
"rx_bytes": "1929999999",
},
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, err := hex.DecodeString(tt.peerKey)
require.NoError(t, err)
key, err := wgtypes.NewKey(res)
require.NoError(t, err)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
})
}
}

View File

@@ -0,0 +1,9 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}