mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Add default firewall rule to allow netbird traffic (#1056)
Add a default firewall rule to allow netbird traffic to be handled by the access control managers. Userspace manager behavior: - When running on Windows, a default rule is add on Windows firewall - For Linux, we are using one of the Kernel managers to add a single rule - This PR doesn't handle macOS Kernel manager behavior: - For NFtables, if there is a filter table, an INPUT rule is added - Iptables follows the previous flow if running on kernel mode. If running on userspace mode, it adds a single rule for INPUT and OUTPUT chains A new checkerFW package has been introduced to consolidate checks across route and access control managers. It supports a new environment variable to skip nftables and allow iptables tests
This commit is contained in:
committed by
GitHub
parent
e4bc76c4de
commit
246abda46d
@@ -40,6 +40,9 @@ const (
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
|
||||
@@ -44,6 +44,7 @@ type Manager struct {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
type ruleset struct {
|
||||
@@ -52,7 +53,7 @@ type ruleset struct {
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
||||
m := &Manager{
|
||||
wgIface: wgIface,
|
||||
inputDefaultRuleSpecs: []string{
|
||||
@@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
rulesets: make(map[string]ruleset),
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
err := ipset.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
if isIptablesClientAvailable(ipv4Client) {
|
||||
m.ipv4Client = ipv4Client
|
||||
|
||||
if ipv6Supported {
|
||||
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
|
||||
}
|
||||
}
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||
} else {
|
||||
if isIptablesClientAvailable(ipv6Client) {
|
||||
m.ipv6Client = ipv6Client
|
||||
}
|
||||
if m.ipv4Client == nil && m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
@@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
@@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if m.wgIface.IsUserspaceBind() {
|
||||
_, err := m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"allow netbird interface traffic",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||
}
|
||||
_, err = m.AddFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
fw.RuleDirectionOUT,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"allow netbird interface traffic",
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
@@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
||||
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
@@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
|
||||
@@ -29,6 +29,8 @@ const (
|
||||
|
||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
|
||||
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
@@ -379,7 +381,7 @@ func (m *Manager) chain(
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
||||
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
@@ -399,13 +401,20 @@ func (m *Manager) chain(
|
||||
}
|
||||
|
||||
// table returns the table for the given family of the IP address
|
||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
func (m *Manager) table(
|
||||
family nftables.TableFamily, tableName string,
|
||||
) (*nftables.Table, error) {
|
||||
// we cache access to Netbird ACL table only
|
||||
if tableName != FilterTableName {
|
||||
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||
}
|
||||
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
if m.tableIPv4 != nil {
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
func (m *Manager) createTableIfNotExists(
|
||||
family nftables.TableFamily, tableName string,
|
||||
) (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
if t.Name == tableName {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
family nftables.TableFamily,
|
||||
tableName string,
|
||||
name string,
|
||||
hooknum nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
chainType nftables.ChainType,
|
||||
) (*nftables.Chain, error) {
|
||||
table, err := m.table(family)
|
||||
table, err := m.table(family, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
for _, c := range chains {
|
||||
// delete Netbird allow input traffic rule if it exists
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.rConn.DelChain(c)
|
||||
}
|
||||
@@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
tf := nftables.TableFamilyIPv4
|
||||
if m.wgIface.Address().IP.To4() == nil {
|
||||
tf = nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(tf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chain == nil {
|
||||
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||
}
|
||||
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
m.applyAllowNetbirdRules(chain)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
@@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(AllowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
}
|
||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||
continue
|
||||
}
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
|
||||
19
client/firewall/uspfilter/allow_netbird.go
Normal file
19
client/firewall/uspfilter/allow_netbird.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !windows && !linux
|
||||
|
||||
package uspfilter
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package uspfilter
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.resetHook != nil {
|
||||
return m.resetHook()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
91
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
91
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type action string
|
||||
|
||||
const (
|
||||
addRule action = "add"
|
||||
deleteRule action = "delete"
|
||||
|
||||
firewallRuleName = "Netbird"
|
||||
noRulesMatchCriteria = "No rules match the specified criteria"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
return manageFirewallRule(firewallRuleName,
|
||||
addRule,
|
||||
"dir=in",
|
||||
"enable=yes",
|
||||
"action=allow",
|
||||
"profile=any",
|
||||
"localip="+m.wgIface.Address().IP.String(),
|
||||
)
|
||||
}
|
||||
|
||||
func manageFirewallRule(ruleName string, action action, args ...string) error {
|
||||
active, err := isFirewallRuleActive(ruleName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if (action == addRule && !active) || (action == deleteRule && active) {
|
||||
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||
args := append(baseArgs, args...)
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isFirewallRuleActive(ruleName string) (bool, error) {
|
||||
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||
|
||||
cmd := exec.Command("netsh", args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
// if the firewall rule is not active, we expect last exit code to be 1
|
||||
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
|
||||
if exitStatus == 1 {
|
||||
if strings.Contains(string(output), noRulesMatchCriteria) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
if strings.Contains(string(output), noRulesMatchCriteria) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -19,6 +19,7 @@ const layerTypeAll = 0
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
Address() iface.WGAddress
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
@@ -30,6 +31,8 @@ type Manager struct {
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
resetHook func() error
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
@@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
@@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// SetResetHook which will be executed in the end of Reset method
|
||||
func (m *Manager) SetResetHook(hook func() error) {
|
||||
m.resetHook = hook
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(iface.PacketFilter) error
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc == nil {
|
||||
return iface.WGAddress{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||
|
||||
Reference in New Issue
Block a user