mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Add logic layer for the ACL firewall rules management.
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,6 +17,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||||
@@ -112,7 +114,9 @@ type Engine struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
|
firewallManager firewall.Manager
|
||||||
|
firewallRules map[string]firewall.Rule
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
}
|
}
|
||||||
@@ -219,6 +223,12 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||||
|
|
||||||
|
e.firewallManager, err = buildFirewallManager()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to create firewall manager, ACL policy will not work: %s", err.Error())
|
||||||
|
}
|
||||||
|
e.firewallRules = make(map[string]firewall.Rule)
|
||||||
|
|
||||||
if e.dnsServer == nil {
|
if e.dnsServer == nil {
|
||||||
// todo fix custom address
|
// todo fix custom address
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||||
@@ -622,6 +632,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.applyFirewallRules(networkMap.FirewallRules); err != nil {
|
||||||
|
log.Errorf("failed apply firewall rules, err: %v", err)
|
||||||
|
}
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1030,6 +1043,96 @@ func (e *Engine) close() {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyFirewallRules to the local firewall manager processed by ACL policy.
|
||||||
|
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
|
||||||
|
if e.firewallManager == nil {
|
||||||
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newRules := make([]string, 0)
|
||||||
|
for _, r := range rules {
|
||||||
|
rule := e.protoRuleToFirewallRule(r)
|
||||||
|
if rule == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newRules = append(newRules, rule.GetRuleID())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ruleID := range newRules {
|
||||||
|
if rule, ok := e.firewallRules[ruleID]; ok {
|
||||||
|
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||||
|
log.Debug("failed to delete firewall rule: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(e.firewallRules, ruleID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule {
|
||||||
|
ip := net.ParseIP(r.PeerIP)
|
||||||
|
if ip == nil {
|
||||||
|
log.Debug("invalid IP address, skipping firewall rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var port firewall.Port
|
||||||
|
if r.Port != "" {
|
||||||
|
split := strings.Split(r.Port, "/")
|
||||||
|
value, err := strconv.Atoi(split[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Debug("invalid port, skipping firewall rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
port.Values = []int{value}
|
||||||
|
// get protocol from the port suffix if it exists
|
||||||
|
if len(split) > 1 {
|
||||||
|
switch split[1] {
|
||||||
|
case "tcp":
|
||||||
|
port.Proto = firewall.PortProtocolTCP
|
||||||
|
case "udp":
|
||||||
|
port.Proto = firewall.PortProtocolUDP
|
||||||
|
default:
|
||||||
|
log.Debug("invalid protocol, skipping firewall rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var direction firewall.Direction
|
||||||
|
switch r.Direction {
|
||||||
|
case "src":
|
||||||
|
direction = firewall.DirectionSrc
|
||||||
|
case "dst":
|
||||||
|
direction = firewall.DirectionDst
|
||||||
|
default:
|
||||||
|
log.Debug("invalid direction, skipping firewall rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var action firewall.Action
|
||||||
|
switch r.Action {
|
||||||
|
case "accept":
|
||||||
|
action = firewall.ActionAccept
|
||||||
|
case "drop":
|
||||||
|
action = firewall.ActionDrop
|
||||||
|
default:
|
||||||
|
log.Debug("invalid action, skipping firewall rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := e.firewallManager.AddFiltering(ip, &port, direction, action, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Debug("failed to add firewall rule: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
e.firewallRules[rule.GetRuleID()] = rule
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
14
client/internal/firewall.go
Normal file
14
client/internal/firewall.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildFirewallManager() (fw firewall.Manager, err error) {
|
||||||
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
17
client/internal/firewall_linux.go
Normal file
17
client/internal/firewall_linux.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildFirewallManager() (fw firewall.Manager, err error) {
|
||||||
|
fw, err = iptables.Create()
|
||||||
|
if err != nil {
|
||||||
|
// TODO: handle init nftables manager when it will be implemented
|
||||||
|
return nil, fmt.Errorf("create iptables manager: %w", err)
|
||||||
|
}
|
||||||
|
return fw, nil
|
||||||
|
}
|
||||||
@@ -272,7 +272,7 @@ func (a *Account) GetGroup(groupID string) *Group {
|
|||||||
|
|
||||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||||
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
||||||
aclPeers, _ := a.getPeersByPolicy(peerID)
|
aclPeers, firewallRules := a.getPeersByPolicy(peerID)
|
||||||
// exclude expired peers
|
// exclude expired peers
|
||||||
var peersToConnect []*Peer
|
var peersToConnect []*Peer
|
||||||
var expiredPeers []*Peer
|
var expiredPeers []*Peer
|
||||||
@@ -303,11 +303,12 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &NetworkMap{
|
return &NetworkMap{
|
||||||
Peers: peersToConnect,
|
Peers: peersToConnect,
|
||||||
Network: a.Network.Copy(),
|
Network: a.Network.Copy(),
|
||||||
Routes: routesUpdate,
|
Routes: routesUpdate,
|
||||||
DNSConfig: dnsUpdate,
|
DNSConfig: dnsUpdate,
|
||||||
OfflinePeers: expiredPeers,
|
OfflinePeers: expiredPeers,
|
||||||
|
FirewallRules: firewallRules,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -420,6 +420,8 @@ func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials
|
|||||||
|
|
||||||
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
||||||
|
|
||||||
|
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||||
|
|
||||||
return &proto.SyncResponse{
|
return &proto.SyncResponse{
|
||||||
WiretrusteeConfig: wtConfig,
|
WiretrusteeConfig: wtConfig,
|
||||||
PeerConfig: pConfig,
|
PeerConfig: pConfig,
|
||||||
@@ -433,6 +435,7 @@ func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials
|
|||||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||||
Routes: routesUpdate,
|
Routes: routesUpdate,
|
||||||
DNSConfig: dnsUpdate,
|
DNSConfig: dnsUpdate,
|
||||||
|
FirewallRules: firewallRules,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,11 +23,12 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type NetworkMap struct {
|
type NetworkMap struct {
|
||||||
Peers []*Peer
|
Peers []*Peer
|
||||||
Network *Network
|
Network *Network
|
||||||
Routes []*route.Route
|
Routes []*route.Route
|
||||||
DNSConfig nbdns.Config
|
DNSConfig nbdns.Config
|
||||||
OfflinePeers []*Peer
|
OfflinePeers []*Peer
|
||||||
|
FirewallRules []*FirewallRule
|
||||||
}
|
}
|
||||||
|
|
||||||
type Network struct {
|
type Network struct {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"html/template"
|
"html/template"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
|
||||||
@@ -449,3 +450,17 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (e
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||||
|
result := make([]*proto.FirewallRule, len(update))
|
||||||
|
for i := range update {
|
||||||
|
result[i] = &proto.FirewallRule{
|
||||||
|
PeerID: update[i].PeerID,
|
||||||
|
PeerIP: update[i].PeerIP,
|
||||||
|
Direction: update[i].Direction,
|
||||||
|
Port: update[i].Port,
|
||||||
|
Action: update[i].Action,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user