mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
91 Commits
debug-api
...
restore-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70398ea125 | ||
|
|
9fa75e0ac5 | ||
|
|
64f27aee55 | ||
|
|
78b86e0beb | ||
|
|
18871b554f | ||
|
|
76d73548d6 | ||
|
|
11828a064a | ||
|
|
0c2a3dd937 | ||
|
|
47dcf8d68c | ||
|
|
cc8f6bcaf3 | ||
|
|
d8bcf745b0 | ||
|
|
8430139d80 | ||
|
|
a2962b4ce0 | ||
|
|
16fffdb75b | ||
|
|
036cecbf46 | ||
|
|
3482852bb6 | ||
|
|
fd62665b1f | ||
|
|
36da464413 | ||
|
|
86370a0e7b | ||
|
|
cb16d0f45f | ||
|
|
e8d8bd8f18 | ||
|
|
8b07f21c28 | ||
|
|
54be772ffd | ||
|
|
3c3a454e61 | ||
|
|
5ff77b3595 | ||
|
|
b180edbe5c | ||
|
|
0a042ac36d | ||
|
|
e9f11fb11b | ||
|
|
419ed275fa | ||
|
|
89a55bcf4e | ||
|
|
2d4fcaf186 | ||
|
|
acf172b52c | ||
|
|
8c81a823fa | ||
|
|
619c549547 | ||
|
|
9a713a0987 | ||
|
|
c4945cd565 | ||
|
|
1e10c17ecb | ||
|
|
96d5190436 | ||
|
|
d19c26df06 | ||
|
|
36e36414d9 | ||
|
|
7e69589e05 | ||
|
|
aa613ab79a | ||
|
|
6ead0ff95e | ||
|
|
0db65a8984 | ||
|
|
c138807e95 | ||
|
|
637c0c8949 | ||
|
|
c72e13d8e6 | ||
|
|
f6d7bccfa0 | ||
|
|
e3ed01cafb | ||
|
|
fa748a7ec2 | ||
|
|
cccc615783 | ||
|
|
2021463ca0 | ||
|
|
f48cfd52e9 | ||
|
|
6838f53f40 | ||
|
|
8276236dfa | ||
|
|
994b923d56 | ||
|
|
59e2432231 | ||
|
|
eee0d123e4 | ||
|
|
e943203ae2 | ||
|
|
6a775217cf | ||
|
|
175674749f | ||
|
|
1e534cecf6 | ||
|
|
aa3aa8c6a8 | ||
|
|
fbdfe45c25 | ||
|
|
81ee172db8 | ||
|
|
f8fd65a65f | ||
|
|
62b978c050 | ||
|
|
4ebf1410c6 | ||
|
|
630edf2480 | ||
|
|
ea469d28d7 | ||
|
|
597f1d47b8 | ||
|
|
fcc96417f9 | ||
|
|
8755211a60 | ||
|
|
e6d4653b08 | ||
|
|
eb69f2de78 | ||
|
|
206420c085 | ||
|
|
88a864c195 | ||
|
|
a789e9e6d8 | ||
|
|
9930913e4e | ||
|
|
48675f579f | ||
|
|
afec455f86 | ||
|
|
035c5d9f23 | ||
|
|
b2a5b29fb2 | ||
|
|
9ec61206c2 | ||
|
|
1b011a2d85 | ||
|
|
a85ea1ddb0 | ||
|
|
829e40d2aa | ||
|
|
6344e34880 | ||
|
|
a76ca8c565 | ||
|
|
26693e4ea8 | ||
|
|
f6a71f4193 |
@@ -90,13 +90,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock())
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,17 +10,18 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
|
|||||||
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
_ string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -97,7 +97,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -121,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -128,7 +129,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
|
|||||||
@@ -65,13 +65,13 @@ type Manager interface {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
AddPeerFiltering(
|
AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -80,7 +80,15 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto Protocol,
|
||||||
|
sPort *Port,
|
||||||
|
dPort *Port,
|
||||||
|
action Action,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
// DeleteRouteRule deletes a routing rule
|
||||||
DeleteRouteRule(rule Rule) error
|
DeleteRouteRule(rule Rule) error
|
||||||
|
|||||||
@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *AclManager) AddPeerFiltering(
|
func (m *AclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var ipset *nftables.Set
|
var ipset *nftables.Set
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
@@ -102,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -256,7 +256,6 @@ func (m *AclManager) addIOFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipset *nftables.Set,
|
ipset *nftables.Set,
|
||||||
comment string,
|
|
||||||
) (*Rule, error) {
|
) (*Rule, error) {
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
@@ -338,7 +337,7 @@ func (m *AclManager) addIOFiltering(
|
|||||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(ruleId)
|
||||||
|
|
||||||
chain := m.chainInputRules
|
chain := m.chainInputRules
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
|
|||||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@@ -228,6 +228,7 @@ func (r *router) createContainers() error {
|
|||||||
|
|
||||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -236,7 +237,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -16,8 +17,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -31,8 +32,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,28 +22,31 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Close closes the firewall manager
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -1,20 +1,27 @@
|
|||||||
// common.go
|
|
||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"fmt"
|
||||||
"sync"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseConnTrack provides common fields and locking for all connection types
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
type BaseConnTrack struct {
|
type BaseConnTrack struct {
|
||||||
SourceIP net.IP
|
FlowId uuid.UUID
|
||||||
DestIP net.IP
|
Direction nftypes.Direction
|
||||||
SourcePort uint16
|
SourceIP netip.Addr
|
||||||
DestPort uint16
|
DestIP netip.Addr
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64
|
||||||
|
PacketsTx atomic.Uint64
|
||||||
|
PacketsRx atomic.Uint64
|
||||||
|
BytesTx atomic.Uint64
|
||||||
|
BytesRx atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateCounters safely updates the packet and byte counters
|
||||||
|
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||||
|
if direction == nftypes.Egress {
|
||||||
|
b.PacketsTx.Add(1)
|
||||||
|
b.BytesTx.Add(uint64(bytes))
|
||||||
|
} else {
|
||||||
|
b.PacketsRx.Add(1)
|
||||||
|
b.BytesRx.Add(uint64(bytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
|||||||
return time.Since(lastSeen) > timeout
|
return time.Since(lastSeen) > timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPAddr is a fixed-size IP address to avoid allocations
|
|
||||||
type IPAddr [16]byte
|
|
||||||
|
|
||||||
// MakeIPAddr creates an IPAddr from net.IP
|
|
||||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
|
||||||
// Optimization: check for v4 first as it's more common
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
copy(addr[12:], ip4)
|
|
||||||
} else {
|
|
||||||
copy(addr[:], ip.To16())
|
|
||||||
}
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnKey uniquely identifies a connection
|
// ConnKey uniquely identifies a connection
|
||||||
type ConnKey struct {
|
type ConnKey struct {
|
||||||
SrcIP IPAddr
|
SrcIP netip.Addr
|
||||||
DstIP IPAddr
|
DstIP netip.Addr
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeConnKey creates a connection key
|
func (c ConnKey) String() string {
|
||||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
return ConnKey{
|
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
|
||||||
DstIP: MakeIPAddr(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateIPs checks if IPs match without allocation
|
|
||||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
|
||||||
if ip4 := pktIP.To4(); ip4 != nil {
|
|
||||||
// Compare IPv4 addresses (last 4 bytes)
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
if connIP[12+i] != ip4[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Compare full IPv6 addresses
|
|
||||||
ip6 := pktIP.To16()
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
if connIP[i] != ip6[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
|
||||||
type PreallocatedIPs struct {
|
|
||||||
sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPreallocatedIPs creates a new IP pool
|
|
||||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
|
||||||
return &PreallocatedIPs{
|
|
||||||
Pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
ip := make(net.IP, 16)
|
|
||||||
return &ip
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves an IP from the pool
|
|
||||||
func (p *PreallocatedIPs) Get() net.IP {
|
|
||||||
return *p.Pool.Get().(*net.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put returns an IP to the pool
|
|
||||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
|
||||||
p.Pool.Put(&ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyIP copies an IP address efficiently
|
|
||||||
func copyIP(dst, src net.IP) {
|
|
||||||
if len(src) == 16 {
|
|
||||||
copy(dst, src)
|
|
||||||
} else {
|
|
||||||
// Handle IPv4
|
|
||||||
copy(dst[12:], src.To4())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,94 +1,67 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = MakeIPAddr(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ValidateIPs", func(b *testing.B) {
|
|
||||||
ip1 := net.ParseIP("192.168.1.1")
|
|
||||||
ip2 := net.ParseIP("192.168.1.1")
|
|
||||||
addr := MakeIPAddr(ip1)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ValidateIPs(addr, ip2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IPPool", func(b *testing.B) {
|
|
||||||
pool := NewPreallocatedIPs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
ip := pool.Get()
|
|
||||||
pool.Put(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,18 +23,20 @@ const (
|
|||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
type ICMPConnKey struct {
|
type ICMPConnKey struct {
|
||||||
// Supports both IPv4 and IPv6
|
SrcIP netip.Addr
|
||||||
SrcIP [16]byte
|
DstIP netip.Addr
|
||||||
DstIP [16]byte
|
ID uint16
|
||||||
Sequence uint16 // ICMP sequence number
|
}
|
||||||
ID uint16 // ICMP identifier
|
|
||||||
|
func (i ICMPConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
type ICMPConnTrack struct {
|
type ICMPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
Sequence uint16
|
ICMPType uint8
|
||||||
ID uint16
|
ICMPCode uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
@@ -42,11 +47,11 @@ type ICMPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
@@ -59,67 +64,107 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP Echo Request
|
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
key := ICMPConnKey{
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
t.mutex.Lock()
|
ID: id,
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &ICMPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
},
|
|
||||||
ID: id,
|
|
||||||
Sequence: seq,
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New ICMP connection %v", key)
|
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound ICMP connection
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
|
||||||
|
// non echo requests don't need tracking
|
||||||
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
key := ICMPConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
conn.UpdateLastSeen()
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
conn.ID == id &&
|
|
||||||
conn.Sequence == seq
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
@@ -134,17 +179,18 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanup() {
|
func (t *ICMPTracker) cleanup() {
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,20 +200,46 @@ func (t *ICMPTracker) Close() {
|
|||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeICMPKey creates an ICMP connection key
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
return ICMPConnKey{
|
FlowID: conn.FlowId,
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
Type: typ,
|
||||||
DstIP: MakeIPAddr(dstIP),
|
RuleID: ruleID,
|
||||||
ID: id,
|
Direction: conn.Direction,
|
||||||
Sequence: seq,
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
}
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
ICMPType: conn.ICMPType,
|
||||||
|
ICMPCode: conn.ICMPCode,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeStart,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
fields.RxPackets = 1
|
||||||
|
fields.RxBytes = uint64(size)
|
||||||
|
} else {
|
||||||
|
fields.TxPackets = 1
|
||||||
|
fields.TxBytes = uint64(size)
|
||||||
|
}
|
||||||
|
t.flowLogger.StoreEvent(fields)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,39 +1,39 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -40,6 +43,35 @@ const (
|
|||||||
// TCPState represents the state of a TCP connection
|
// TCPState represents the state of a TCP connection
|
||||||
type TCPState int
|
type TCPState int
|
||||||
|
|
||||||
|
func (s TCPState) String() string {
|
||||||
|
switch s {
|
||||||
|
case TCPStateNew:
|
||||||
|
return "New"
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return "SYN Sent"
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return "SYN Received"
|
||||||
|
case TCPStateEstablished:
|
||||||
|
return "Established"
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return "FIN Wait 1"
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return "FIN Wait 2"
|
||||||
|
case TCPStateClosing:
|
||||||
|
return "Closing"
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
return "Time Wait"
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return "Close Wait"
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return "Last ACK"
|
||||||
|
case TCPStateClosed:
|
||||||
|
return "Closed"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPStateNew TCPState = iota
|
TCPStateNew TCPState = iota
|
||||||
TCPStateSynSent
|
TCPStateSynSent
|
||||||
@@ -54,19 +86,14 @@ const (
|
|||||||
TCPStateClosed
|
TCPStateClosed
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCPConnKey uniquely identifies a TCP connection
|
|
||||||
type TCPConnKey struct {
|
|
||||||
SrcIP [16]byte
|
|
||||||
DstIP [16]byte
|
|
||||||
SrcPort uint16
|
|
||||||
DstPort uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCPConnTrack represents a TCP connection state
|
// TCPConnTrack represents a TCP connection state
|
||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
State TCPState
|
State TCPState
|
||||||
established atomic.Bool
|
established atomic.Bool
|
||||||
|
tombstone atomic.Bool
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,6 +107,16 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
|
|||||||
t.established.Store(state)
|
t.established.Store(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
|
func (t *TCPConnTrack) IsTombstone() bool {
|
||||||
|
return t.tombstone.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTombstone safely marks the connection for deletion
|
||||||
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
|
t.tombstone.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
type TCPTracker struct {
|
type TCPTracker struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
@@ -88,11 +125,14 @@ type TCPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultTCPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
@@ -102,59 +142,91 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
|||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
key := ConnKey{
|
||||||
// Create key before lock
|
SrcIP: srcIP,
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||||
|
conn.Unlock()
|
||||||
|
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound TCP connection
|
||||||
|
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
|
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.established.Store(false)
|
||||||
|
conn.tombstone.Store(false)
|
||||||
|
|
||||||
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
|
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
conn, exists := t.connections[key]
|
t.connections[key] = conn
|
||||||
if !exists {
|
|
||||||
// Use preallocated IPs
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &TCPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
State: TCPStateNew,
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
conn.established.Store(false)
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
// Lock individual connection for state update
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, true)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||||
if !isValidFlagCombination(flags) {
|
key := ConnKey{
|
||||||
return false
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
@@ -163,22 +235,26 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle RST packets
|
// Handle RST flag specially - it always causes transition to closed
|
||||||
if flags&TCPRst != 0 {
|
if flags&TCPRst != 0 {
|
||||||
conn.Lock()
|
if conn.IsTombstone() {
|
||||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
conn.Unlock()
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
conn.SetTombstone()
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
return false
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(conn, flags, false)
|
t.updateState(key, conn, flags, false)
|
||||||
conn.UpdateLastSeen()
|
|
||||||
isEstablished := conn.IsEstablished()
|
isEstablished := conn.IsEstablished()
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
@@ -187,18 +263,17 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// updateState updates the TCP connection state based on flags
|
||||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
// Handle RST flag specially - it always causes transition to closed
|
conn.UpdateLastSeen()
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
state := conn.State
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
defer func() {
|
||||||
return
|
if state != conn.State {
|
||||||
}
|
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
switch conn.State {
|
switch state {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
conn.State = TCPStateSynSent
|
conn.State = TCPStateSynSent
|
||||||
@@ -207,11 +282,11 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
if isOutbound {
|
if isOutbound {
|
||||||
conn.State = TCPStateSynReceived
|
|
||||||
} else {
|
|
||||||
// Simultaneous open
|
|
||||||
conn.State = TCPStateEstablished
|
conn.State = TCPStateEstablished
|
||||||
conn.SetEstablished(true)
|
conn.SetEstablished(true)
|
||||||
|
} else {
|
||||||
|
// Simultaneous open
|
||||||
|
conn.State = TCPStateSynReceived
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,22 +304,32 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
conn.State = TCPStateCloseWait
|
conn.State = TCPStateCloseWait
|
||||||
}
|
}
|
||||||
conn.SetEstablished(false)
|
conn.SetEstablished(false)
|
||||||
|
} else if flags&TCPRst != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
switch {
|
switch {
|
||||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
// Simultaneous close - both sides sent FIN
|
|
||||||
conn.State = TCPStateClosing
|
conn.State = TCPStateClosing
|
||||||
case flags&TCPFin != 0:
|
case flags&TCPFin != 0:
|
||||||
conn.State = TCPStateFinWait2
|
conn.State = TCPStateFinWait2
|
||||||
case flags&TCPAck != 0:
|
case flags&TCPAck != 0:
|
||||||
conn.State = TCPStateFinWait2
|
conn.State = TCPStateFinWait2
|
||||||
|
case flags&TCPRst != 0:
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait2:
|
case TCPStateFinWait2:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection %s completed", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
@@ -252,8 +337,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@@ -264,17 +349,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
|
||||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
// Send close event for gracefully closed connections
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
|
||||||
// This is handled by the cleanup routine
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,6 +417,12 @@ func (t *TCPTracker) cleanup() {
|
|||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
// Clean up tombstoned connections without sending an event
|
||||||
|
delete(t.connections, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
switch {
|
switch {
|
||||||
case conn.State == TCPStateTimeWait:
|
case conn.State == TCPStateTimeWait:
|
||||||
@@ -347,14 +433,16 @@ func (t *TCPTracker) cleanup() {
|
|||||||
timeout = TCPHandshakeTimeout
|
timeout = TCPHandshakeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
lastSeen := conn.GetLastSeen()
|
if conn.timeoutExceeded(timeout) {
|
||||||
if time.Since(lastSeen) > timeout {
|
|
||||||
// Return IPs to pool
|
// Return IPs to pool
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||||
|
|
||||||
|
// event already handled by state change
|
||||||
|
if conn.State != TCPStateTimeWait {
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -365,10 +453,6 @@ func (t *TCPTracker) Close() {
|
|||||||
|
|
||||||
// Clean up all remaining IPs
|
// Clean up all remaining IPs
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -386,3 +470,21 @@ func isValidFlagCombination(flags uint8) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Send initial SYN
|
// Send initial SYN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
// Receive SYN-ACK
|
// Receive SYN-ACK
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
// Send ACK
|
// Send ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
// Test data transfer
|
// Test data transfer
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||||
require.True(t, valid, "Data should be allowed after handshake")
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Send FIN
|
// Send FIN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
// Receive ACK for FIN
|
// Receive ACK for FIN
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "ACK for FIN should be allowed")
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
// Receive FIN from other side
|
// Receive FIN from other side
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "FIN should be allowed")
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
// Send final ACK
|
// Send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Receive RST
|
// Receive RST
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
require.True(t, valid, "RST should be allowed for established connection")
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||||
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Both sides send FIN+ACK
|
// Both sides send FIN+ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
// Both sides send final ACK
|
// Both sides send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "Final ACKs should be allowed")
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST in established",
|
name: "RST in established",
|
||||||
setupState: func() {
|
setupState: func() {
|
||||||
// Establish connection first
|
// Establish connection first
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: true,
|
wantValid: true,
|
||||||
desc: "Should accept RST for established connection",
|
desc: "Should accept RST for established connection",
|
||||||
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST without connection",
|
name: "RST without connection",
|
||||||
setupState: func() {},
|
setupState: func() {},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: false,
|
wantValid: false,
|
||||||
desc: "Should reject RST without connection",
|
desc: "Should reject RST without connection",
|
||||||
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
tt.sendRST()
|
tt.sendRST()
|
||||||
|
|
||||||
// Verify connection state is as expected
|
// Verify connection state is as expected
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn := tracker.connections[key]
|
conn := tracker.connections[key]
|
||||||
if tt.wantValid {
|
if tt.wantValid {
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
@@ -220,63 +225,63 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper to establish a TCP connection
|
// Helper to establish a TCP connection
|
||||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
i := 0
|
i := 0
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
} else {
|
} else {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||||
}
|
}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
@@ -287,14 +292,14 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
// Benchmark connection cleanup
|
// Benchmark connection cleanup
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
for i := 0; i < 10000; i++ {
|
for i := 0; i < 10000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for connections to expire
|
// Wait for connections to expire
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -19,6 +22,8 @@ const (
|
|||||||
// UDPConnTrack represents a UDP connection state
|
// UDPConnTrack represents a UDP connection state
|
||||||
type UDPConnTrack struct {
|
type UDPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
@@ -29,11 +34,11 @@ type UDPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
@@ -46,7 +51,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
@@ -54,55 +59,87 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.mutex.Lock()
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &UDPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New UDP connection: %v", conn)
|
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// TrackInbound records an inbound UDP connection
|
||||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
conn.UpdateLastSeen()
|
||||||
return false
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
return true
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.DestPort == srcPort &&
|
|
||||||
conn.SourcePort == dstPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRoutine periodically removes stale connections
|
// cleanupRoutine periodically removes stale connections
|
||||||
@@ -125,11 +162,11 @@ func (t *UDPTracker) cleanup() {
|
|||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -139,29 +176,44 @@ func (t *UDPTracker) Close() {
|
|||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnection safely retrieves a connection state
|
// GetConnection safely retrieves a connection state
|
||||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
defer t.mutex.RUnlock()
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
conn, exists := t.connections[key]
|
SrcIP: srcIP,
|
||||||
if !exists {
|
DstIP: dstIP,
|
||||||
return nil, false
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
|
conn, exists := t.connections[key]
|
||||||
return conn, true
|
return conn, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout returns the configured timeout duration for the tracker
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
func (t *UDPTracker) Timeout() time.Duration {
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
return t.timeout
|
return t.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout, logger)
|
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@@ -41,43 +41,48 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn, exists := tracker.connections[key]
|
conn, exists := tracker.connections[key]
|
||||||
require.True(t, exists)
|
require.True(t, exists)
|
||||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1*time.Second, logger)
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
// Track outbound connection
|
// Track outbound connection
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
sleep time.Duration
|
sleep time.Duration
|
||||||
@@ -94,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid source IP",
|
name: "invalid source IP",
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: srcIP,
|
dstIP: srcIP,
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
@@ -104,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid destination IP",
|
name: "invalid destination IP",
|
||||||
srcIP: dstIP,
|
srcIP: dstIP,
|
||||||
dstIP: net.ParseIP("192.168.1.4"),
|
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
sleep: 0,
|
sleep: 0,
|
||||||
@@ -144,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
if tt.sleep > 0 {
|
if tt.sleep > 0 {
|
||||||
time.Sleep(tt.sleep)
|
time.Sleep(tt.sleep)
|
||||||
}
|
}
|
||||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||||
assert.Equal(t, tt.want, got)
|
assert.Equal(t, tt.want, got)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -164,8 +169,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
tickerCancel: tickerCancel,
|
tickerCancel: tickerCancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@@ -173,27 +178,27 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
// Add some connections
|
// Add some connections
|
||||||
connections := []struct {
|
connections := []struct {
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.2"),
|
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
dstIP: net.ParseIP("192.168.1.3"),
|
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||||
srcPort: 12345,
|
srcPort: 12345,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: net.ParseIP("192.168.1.5"),
|
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||||
srcPort: 12346,
|
srcPort: 12346,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, conn := range connections {
|
for _, conn := range connections {
|
||||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify initial connections
|
// Verify initial connections
|
||||||
@@ -215,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
|||||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type epID stack.TransportEndpointID
|
||||||
|
|
||||||
|
func (i epID) String() string {
|
||||||
|
// src and remote is swapped
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -29,6 +30,7 @@ const (
|
|||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
@@ -38,7 +40,7 @@ type Forwarder struct {
|
|||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &Forwarder{
|
f := &Forwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
stack: s,
|
stack: s,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
udpForwarder: newUDPForwarder(mtu, logger),
|
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
|
|||||||
@@ -3,14 +3,30 @@ package forwarder
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleICMP handles ICMP packets from the network stack
|
// handleICMP handles ICMP packets from the network stack
|
||||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
icmpType := uint8(icmpHdr.Type())
|
||||||
|
icmpCode := uint8(icmpHdr.Code())
|
||||||
|
|
||||||
|
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
// TODO: support non-root
|
// TODO: support non-root
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP}
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
// Get the complete ICMP message (header + data)
|
|
||||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
payload := fullPacket.AsSlice()
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||||
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
// For Echo Requests, send and handle response
|
// For Echo Requests, send and handle response
|
||||||
switch icmpHdr.Type() {
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
case header.ICMPv4Echo:
|
f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||||
case header.ICMPv4EchoReply:
|
|
||||||
// dont process our own replies
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
_, err = conn.WriteTo(payload, dst)
|
|
||||||
if err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, f.endpoint.mtu)
|
response := make([]byte, f.endpoint.mtu)
|
||||||
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
if !isTimeout(err) {
|
if !isTimeout(err) {
|
||||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
}
|
}
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
return true
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
return true
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
|
// TODO: get packets/bytes
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,24 +5,38 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleTCP is called by the TCP forwarder for new connections.
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
id := r.ID()
|
id := r.ID()
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
inConn := gonet.NewTCPConn(&wq, ep)
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyTCP(id, inConn, outConn, ep)
|
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
}
|
}
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
// Create context for managing the proxy goroutines
|
||||||
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@@ -16,6 +18,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,15 +31,17 @@ type udpPacketConn struct {
|
|||||||
lastSeen atomic.Int64
|
lastSeen atomic.Int64
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ep tcpip.Endpoint
|
ep tcpip.Endpoint
|
||||||
|
flowID uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpForwarder struct {
|
type udpForwarder struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
flowLogger nftypes.FlowLogger
|
||||||
bufPool sync.Pool
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
ctx context.Context
|
bufPool sync.Pool
|
||||||
cancel context.CancelFunc
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type idleConn struct {
|
type idleConn struct {
|
||||||
@@ -44,13 +49,14 @@ type idleConn struct {
|
|||||||
conn *udpPacketConn
|
conn *udpPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
flowLogger: flowLogger,
|
||||||
ctx: ctx,
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
cancel: cancel,
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
b := make([]byte, mtu)
|
b := make([]byte, mtu)
|
||||||
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
|
|||||||
for id, conn := range f.conns {
|
for id, conn := range f.conns {
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
if err := conn.conn.Close(); err != nil {
|
if err := conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := conn.outConn.Close(); err != nil {
|
if err := conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.ep.Close()
|
conn.ep.Close()
|
||||||
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
|
|||||||
for _, idle := range idleConns {
|
for _, idle := range idleConns {
|
||||||
idle.conn.cancel()
|
idle.conn.cancel()
|
||||||
if err := idle.conn.conn.Close(); err != nil {
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
if err := idle.conn.outConn.Close(); err != nil {
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idle.conn.ep.Close()
|
idle.conn.ep.Close()
|
||||||
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
|
|||||||
delete(f.conns, idle.id)
|
delete(f.conns, idle.id)
|
||||||
f.Unlock()
|
f.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
_, exists := f.udpForwarder.conns[id]
|
_, exists := f.udpForwarder.conns[id]
|
||||||
f.udpForwarder.RUnlock()
|
f.udpForwarder.RUnlock()
|
||||||
if exists {
|
if exists {
|
||||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
// TODO: Send ICMP error message
|
// TODO: Send ICMP error message
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if epErr != nil {
|
if epErr != nil {
|
||||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
outConn: outConn,
|
outConn: outConn,
|
||||||
cancel: connCancel,
|
cancel: connCancel,
|
||||||
ep: ep,
|
ep: ep,
|
||||||
|
flowID: flowID,
|
||||||
}
|
}
|
||||||
pConn.updateLastSeen()
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
defer func() {
|
defer func() {
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ep.Close()
|
ep.Close()
|
||||||
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
delete(f.udpForwarder.conns, id)
|
delete(f.udpForwarder.conns, id)
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) updateLastSeen() {
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
c.lastSeen.Store(time.Now().UnixNano())
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
ipv4 := ip.To4()
|
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||||
if ipv4 == nil {
|
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||||
return false
|
|
||||||
}
|
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
|
||||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if ip.Is4() {
|
||||||
return m.checkBitmapBit(ipv4)
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAddr wgaddr.Address
|
setupAddr wgaddr.Address
|
||||||
testIP net.IP
|
testIP netip.Addr
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -73,7 +74,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -85,7 +86,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(64, 128),
|
Mask: net.CIDRMask(64, 128),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
t.Logf("Testing %d IPs", len(tests))
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,13 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2 // 2KB per message
|
maxMessageSize = 1024 * 2
|
||||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
logChannelSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
// Level represents log severity
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
|||||||
LevelTrace: "TRAC",
|
LevelTrace: "TRAC",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger is a high-performance, non-blocking logger
|
type logMessage struct {
|
||||||
type Logger struct {
|
level Level
|
||||||
output io.Writer
|
format string
|
||||||
level atomic.Uint32
|
args []any
|
||||||
buffer *ringBuffer
|
|
||||||
shutdown chan struct{}
|
|
||||||
closeOnce sync.Once
|
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
// Reusable buffer pool for formatting messages
|
|
||||||
bufPool sync.Pool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
msgChannel chan logMessage
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
buffer: newRingBuffer(bufferSize),
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
// Pre-allocate buffer for message formatting
|
|
||||||
b := make([]byte, 0, maxMessageSize)
|
b := make([]byte, 0, maxMessageSize)
|
||||||
return &b
|
return &b
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logrusLevel := logrusLogger.GetLevel()
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
l.level.Store(uint32(logrusLevel))
|
l.level.Store(uint32(logrusLevel))
|
||||||
level := levelStrings[Level(logrusLevel)]
|
level := levelStrings[Level(logrusLevel)]
|
||||||
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLevel sets the logging level
|
||||||
func (l *Logger) SetLevel(level Level) {
|
func (l *Logger) SetLevel(level Level) {
|
||||||
l.level.Store(uint32(level))
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
func (l *Logger) log(level Level, format string, args ...any) {
|
||||||
*buf = (*buf)[:0]
|
select {
|
||||||
|
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||||
// Timestamp
|
default:
|
||||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Level
|
|
||||||
*buf = append(*buf, levelStrings[level]...)
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Message
|
|
||||||
if len(args) > 0 {
|
|
||||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
|
||||||
} else {
|
|
||||||
*buf = append(*buf, format...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*buf = append(*buf, '\n')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
// Error logs a message at error level
|
||||||
bufp := l.bufPool.Get().(*[]byte)
|
func (l *Logger) Error(format string, args ...any) {
|
||||||
l.formatMessage(bufp, level, format, args...)
|
|
||||||
|
|
||||||
if len(*bufp) > maxMessageSize {
|
|
||||||
*bufp = (*bufp)[:maxMessageSize]
|
|
||||||
}
|
|
||||||
_, _ = l.buffer.Write(*bufp)
|
|
||||||
|
|
||||||
l.bufPool.Put(bufp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
if l.level.Load() >= uint32(LevelError) {
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
l.log(LevelError, format, args...)
|
l.log(LevelError, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
// Warn logs a message at warning level
|
||||||
|
func (l *Logger) Warn(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelWarn) {
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
l.log(LevelWarn, format, args...)
|
l.log(LevelWarn, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
// Info logs a message at info level
|
||||||
|
func (l *Logger) Info(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelInfo) {
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
l.log(LevelInfo, format, args...)
|
l.log(LevelInfo, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
// Debug logs a message at debug level
|
||||||
|
func (l *Logger) Debug(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelDebug) {
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
l.log(LevelDebug, format, args...)
|
l.log(LevelDebug, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
// Trace logs a message at trace level
|
||||||
|
func (l *Logger) Trace(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelTrace) {
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
l.log(LevelTrace, format, args...)
|
l.log(LevelTrace, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker periodically flushes the buffer
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
if len(args) > 0 {
|
||||||
|
msg = fmt.Sprintf(format, args...)
|
||||||
|
} else {
|
||||||
|
msg = format
|
||||||
|
}
|
||||||
|
*buf = append(*buf, msg...)
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
|
||||||
|
if len(*buf) > maxMessageSize {
|
||||||
|
*buf = (*buf)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage handles a single log message and adds it to the buffer
|
||||||
|
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
defer l.bufPool.Put(bufp)
|
||||||
|
|
||||||
|
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||||
|
|
||||||
|
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
*buffer = append(*buffer, *bufp...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushBuffer writes the accumulated buffer to output
|
||||||
|
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||||
|
if len(*buffer) > 0 {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes as many messages as possible without blocking
|
||||||
|
func (l *Logger) processBatch(buffer *[]byte) {
|
||||||
|
for len(*buffer) < maxBatchSize {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||||
|
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(l.msgChannel) == 0 {
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is the main goroutine that processes log messages
|
||||||
func (l *Logger) worker() {
|
func (l *Logger) worker() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
ticker := time.NewTicker(defaultFlushInterval)
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
buf := make([]byte, 0, maxBatchSize)
|
buffer := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-l.shutdown:
|
case <-l.shutdown:
|
||||||
|
l.handleShutdown(&buffer)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Read accumulated messages
|
l.flushBuffer(&buffer)
|
||||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
case msg := <-l.msgChannel:
|
||||||
if n == 0 {
|
l.processMessage(msg, &buffer)
|
||||||
continue
|
l.processBatch(&buffer)
|
||||||
}
|
|
||||||
|
|
||||||
// Write batch
|
|
||||||
_, _ = l.output.Write(buf[:n])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type discard struct{}
|
||||||
|
|
||||||
|
func (d *discard) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
simpleMessage := "Connection established"
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4 // TCPStateEstablished
|
||||||
|
|
||||||
|
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||||
|
protocol := "TCP"
|
||||||
|
direction := "outbound"
|
||||||
|
flags := uint16(0x18) // ACK + PSH
|
||||||
|
sequence := uint32(123456789)
|
||||||
|
acknowledged := uint32(987654321)
|
||||||
|
payloadSize := 1460
|
||||||
|
fragmented := false
|
||||||
|
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||||
|
|
||||||
|
b.Run("SimpleMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(simpleMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ComplexMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||||
|
func BenchmarkLoggerBurst(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestLogger() *log.Logger {
|
||||||
|
logrusLogger := logrus.New()
|
||||||
|
logrusLogger.SetOutput(&discard{})
|
||||||
|
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||||
|
return log.NewFromLogrus(logrusLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLogger(logger *log.Logger) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = logger.Stop(ctx)
|
||||||
|
}
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
package log
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// ringBuffer is a simple ring buffer implementation
|
|
||||||
type ringBuffer struct {
|
|
||||||
buf []byte
|
|
||||||
size int
|
|
||||||
r, w int64 // Read and write positions
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRingBuffer(size int) *ringBuffer {
|
|
||||||
return &ringBuffer{
|
|
||||||
buf: make([]byte, size),
|
|
||||||
size: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
|
||||||
if len(p) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if len(p) > r.size {
|
|
||||||
p = p[:r.size]
|
|
||||||
}
|
|
||||||
|
|
||||||
n = len(p)
|
|
||||||
|
|
||||||
// Write data, handling wrap-around
|
|
||||||
pos := int(r.w % int64(r.size))
|
|
||||||
writeLen := min(len(p), r.size-pos)
|
|
||||||
copy(r.buf[pos:], p[:writeLen])
|
|
||||||
|
|
||||||
// If we have more data and need to wrap around
|
|
||||||
if writeLen < len(p) {
|
|
||||||
copy(r.buf, p[writeLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update write position
|
|
||||||
r.w += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if r.w == r.r {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate available data accounting for wraparound
|
|
||||||
available := int(r.w - r.r)
|
|
||||||
if available < 0 {
|
|
||||||
available += r.size
|
|
||||||
}
|
|
||||||
available = min(available, r.size)
|
|
||||||
|
|
||||||
// Limit read to buffer size
|
|
||||||
toRead := min(available, len(p))
|
|
||||||
if toRead == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read data, handling wrap-around
|
|
||||||
pos := int(r.r % int64(r.size))
|
|
||||||
readLen := min(toRead, r.size-pos)
|
|
||||||
n = copy(p, r.buf[pos:pos+readLen])
|
|
||||||
|
|
||||||
// If we need more data and need to wrap around
|
|
||||||
if readLen < toRead {
|
|
||||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update read position
|
|
||||||
r.r += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// PeerRule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
mgmtId []byte
|
||||||
|
ip netip.Addr
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
sPort *firewall.Port
|
sPort *firewall.Port
|
||||||
dPort *firewall.Port
|
dPort *firewall.Port
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
@@ -31,6 +30,7 @@ func (r *PeerRule) ID() string {
|
|||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id string
|
id string
|
||||||
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
destination netip.Prefix
|
destination netip.Prefix
|
||||||
proto firewall.Protocol
|
proto firewall.Protocol
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketTrace struct {
|
type PacketTrace struct {
|
||||||
SourceIP net.IP
|
SourceIP netip.Addr
|
||||||
DestinationIP net.IP
|
DestinationIP netip.Addr
|
||||||
Protocol string
|
Protocol string
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestinationPort uint16
|
DestinationPort uint16
|
||||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketBuilder struct {
|
type PacketBuilder struct {
|
||||||
SrcIP net.IP
|
SrcIP netip.Addr
|
||||||
DstIP net.IP
|
DstIP netip.Addr
|
||||||
Protocol fw.Protocol
|
Protocol fw.Protocol
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
|||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
SrcIP: p.SrcIP,
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP,
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
return trace
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.handleRouting(trace) {
|
if !m.handleRouting(trace) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return m.handleNativeRouter(trace)
|
return m.handleNativeRouter(trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||||
msg := "No existing connection found"
|
msg := "No existing connection found"
|
||||||
if allowed {
|
if allowed {
|
||||||
msg = m.buildConntrackStateMessage(d)
|
msg = m.buildConntrackStateMessage(d)
|
||||||
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.localForwarding {
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
strRuleId := "<no id>"
|
||||||
|
if ruleId != nil {
|
||||||
|
strRuleId = string(ruleId)
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||||
|
if blocked {
|
||||||
|
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||||
|
trace.AddResult(StagePeerACL, msg, false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
trace.AddResult(StagePeerACL, msg, true)
|
||||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
|
||||||
|
|
||||||
msg := "Allowed by peer ACL rules"
|
|
||||||
if blocked {
|
|
||||||
msg = "Blocked by peer ACL rules"
|
|
||||||
}
|
|
||||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
|
||||||
|
|
||||||
|
// Handle netstack mode
|
||||||
if m.netstack {
|
if m.netstack {
|
||||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
switch {
|
||||||
|
case !m.localForwarding:
|
||||||
|
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||||
|
case m.forwarder.Load() != nil:
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
default:
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
// In normal mode, packets are allowed through for local delivery
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
return false
|
return false
|
||||||
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
|||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||||
proto := getProtocolFromPacket(d)
|
proto, _ := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
msg := "Allowed by route ACLs"
|
strId := string(id)
|
||||||
|
if id == nil {
|
||||||
|
strId = "<no id>"
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||||
if !allowed {
|
if !allowed {
|
||||||
msg = "Blocked by route ACLs"
|
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||||
}
|
}
|
||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData)
|
dropped := m.processOutgoingHooks(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||||
|
t.Logf("Trace results: %v", trace.Results)
|
||||||
|
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
actualStages = append(actualStages, result.Stage)
|
||||||
|
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||||
|
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||||
|
lastResult := trace.Results[len(trace.Results)-1]
|
||||||
|
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||||
|
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTracePacket(t *testing.T) {
|
||||||
|
setupTracerTest := func(statefulMode bool) *Manager {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !statefulMode {
|
||||||
|
m.stateful = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
builder := &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocol == "tcp" {
|
||||||
|
builder.TCPState = &TCPState{SYN: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder
|
||||||
|
}
|
||||||
|
|
||||||
|
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
return &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: "icmp",
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*Manager)
|
||||||
|
packetBuilder func() *PacketBuilder
|
||||||
|
expectedStages []PacketStage
|
||||||
|
expectedAllow bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = true
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithoutForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_NativeRouter",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(true)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_RoutingDisabled",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ConnectionTracking_Hit",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||||
|
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||||
|
return pb
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OutboundTraffic",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPEchoRequest",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPDestinationUnreachable",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithoutHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolUDP
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
hookFunc := func([]byte) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "StatefulDisabled_NoTracking",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.stateful = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
m := setupTracerTest(true)
|
||||||
|
|
||||||
|
tc.setup(m)
|
||||||
|
|
||||||
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||||
|
"172.17.0.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
trace, err := m.TracePacketFromBuilder(pb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyTraceStages(t, trace, tc.expectedStages)
|
||||||
|
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@@ -22,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,9 +67,9 @@ func (r RouteRules) Sort() {
|
|||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
// outgoingRules is used for hooks only
|
// outgoingRules is used for hooks only
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[netip.Addr]RuleSet
|
||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
@@ -79,9 +81,9 @@ type Manager struct {
|
|||||||
// indicates whether server routes are disabled
|
// indicates whether server routes are disabled
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
// indicates whether we forward packets not destined for ourselves
|
// indicates whether we forward packets not destined for ourselves
|
||||||
routingEnabled bool
|
routingEnabled atomic.Bool
|
||||||
// indicates whether we leave forwarding and filtering to the native firewall
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
nativeRouter bool
|
nativeRouter atomic.Bool
|
||||||
// indicates whether we track outbound connections
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
// indicates whether wireguards runs in netstack mode
|
// indicates whether wireguards runs in netstack mode
|
||||||
@@ -94,8 +96,9 @@ type Manager struct {
|
|||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
forwarder *forwarder.Forwarder
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -112,16 +115,16 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes)
|
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
if nativeFirewall == nil {
|
if nativeFirewall == nil {
|
||||||
return nil, errors.New("native firewall is nil")
|
return nil, errors.New("native firewall is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -148,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -166,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
nativeFirewall: nativeFirewall,
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[netip.Addr]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
disableServerRoutes: disableServerRoutes,
|
disableServerRoutes: disableServerRoutes,
|
||||||
routingEnabled: false,
|
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
}
|
}
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
@@ -185,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netstack needs the forwarder for local traffic
|
// netstack needs the forwarder for local traffic
|
||||||
@@ -208,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
@@ -218,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
|||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
if _, err := m.AddRouteFiltering(
|
if _, err := m.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
wgPrefix,
|
wgPrefix,
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
@@ -251,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
log.Info("userspace routing is disabled")
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
case m.disableServerRoutes:
|
case m.disableServerRoutes:
|
||||||
// if server routes are disabled we will let packets pass to the native stack
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
case forceUserspaceRouter:
|
case forceUserspaceRouter:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
@@ -272,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
|||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("native routing is enabled")
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing enabled by default")
|
log.Info("userspace routing enabled by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routingEnabled && !m.nativeRouter {
|
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||||
return m.initForwarder()
|
return m.initForwarder()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
// initForwarder initializes the forwarder, it disables routing on errors
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
func (m *Manager) initForwarder() error {
|
func (m *Manager) initForwarder() error {
|
||||||
if m.forwarder != nil {
|
if m.forwarder.Load() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
intf := m.wgIface.GetWGDevice()
|
intf := m.wgIface.GetWGDevice()
|
||||||
if intf == nil {
|
if intf == nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return errors.New("forwarding not supported")
|
return errors.New("forwarding not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return fmt.Errorf("create forwarder: %w", err)
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder = forwarder
|
m.forwarder.Store(forwarder)
|
||||||
|
|
||||||
log.Debug("forwarder initialized")
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
@@ -326,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -348,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
_ string,
|
_ string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
|
// TODO: fix in upper layers
|
||||||
|
i, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
i = i.Unmap()
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
mgmtId: id,
|
||||||
|
ip: i,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
drop: action == firewall.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
comment: comment,
|
|
||||||
}
|
}
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
if i.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
r.ip = ipNormalized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||||
@@ -391,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
m.incomingRules[r.ip] = make(RuleSet)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -407,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: ruleID,
|
||||||
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
destination: destination,
|
destination: destination,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
@@ -425,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
action: action,
|
action: action,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
m.routeRules = append(m.routeRules, rule)
|
m.routeRules = append(m.routeRules, rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
delete(m.incomingRules[r.ip], r.id)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -497,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData)
|
return m.processOutgoingHooks(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||||
return m.dropFilter(packetData)
|
return m.dropFilter(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@@ -511,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -527,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track all protocols if stateful mode is enabled
|
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||||
if m.stateful {
|
return true
|
||||||
switch d.decoded[1] {
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeICMPv4:
|
|
||||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process UDP hooks even if stateful mode is disabled
|
if m.stateful {
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||||
switch d.decoded[0] {
|
switch d.decoded[0] {
|
||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
return d.ip4.SrcIP, d.ip4.DstIP
|
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
|
return src, dst
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
return d.ip6.SrcIP, d.ip6.DstIP
|
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||||
|
return src, dst
|
||||||
default:
|
default:
|
||||||
return nil, nil
|
return netip.Addr{}, netip.Addr{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|
||||||
flags := getTCPFlags(&d.tcp)
|
|
||||||
m.tcpTracker.TrackOutbound(
|
|
||||||
srcIP,
|
|
||||||
dstIP,
|
|
||||||
uint16(d.tcp.SrcPort),
|
|
||||||
uint16(d.tcp.DstPort),
|
|
||||||
flags,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||||
var flags uint8
|
var flags uint8
|
||||||
if tcp.SYN {
|
if tcp.SYN {
|
||||||
@@ -596,45 +591,70 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
|||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||||
m.udpTracker.TrackOutbound(
|
transport := d.decoded[1]
|
||||||
srcIP,
|
switch transport {
|
||||||
dstIP,
|
case layers.LayerTypeUDP:
|
||||||
uint16(d.udp.SrcPort),
|
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||||
uint16(d.udp.DstPort),
|
case layers.LayerTypeTCP:
|
||||||
)
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
|
||||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
transport := d.decoded[1]
|
||||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
switch transport {
|
||||||
for _, rule := range rules {
|
case layers.LayerTypeUDP:
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||||
return rule.udpHook(packetData)
|
case layers.LayerTypeTCP:
|
||||||
}
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||||
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Check specific destination IP first
|
||||||
|
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
// Check IPv4 unspecified address
|
||||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||||
m.icmpTracker.TrackOutbound(
|
for _, rule := range rules {
|
||||||
srcIP,
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
dstIP,
|
return rule.udpHook(packetData)
|
||||||
d.icmp4.Id,
|
}
|
||||||
d.icmp4.Seq,
|
}
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check IPv6 unspecified address
|
||||||
|
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -643,19 +663,19 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.localipmanager.IsLocalIP(dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||||
@@ -663,10 +683,29 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleLocalTraffic handles local traffic.
|
// handleLocalTraffic handles local traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
if blocked {
|
||||||
srcIP, dstIP)
|
_, pnum := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeDrop,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: pnum,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
// TODO: icmp type/code
|
||||||
|
RxPackets: 1,
|
||||||
|
RxBytes: uint64(size),
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -675,6 +714,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
|
|||||||
return m.handleNetstackLocalTraffic(packetData)
|
return m.handleNetstackLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
|
m.trackInbound(d, srcIP, dstIP, ruleID, size)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -684,12 +726,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject local packet: %v", err)
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -699,30 +741,43 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleRoutedTraffic handles routed traffic.
|
// handleRoutedTraffic handles routed traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||||
// Drop if routing is disabled
|
// Drop if routing is disabled
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
srcIP, dstIP)
|
srcIP, dstIP)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass to native stack if native router is enabled or forced
|
// Pass to native stack if native router is enabled or forced
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := getProtocolFromPacket(d)
|
proto, pnum := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
srcIP, srcPort, dstIP, dstPort, proto)
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeDrop,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: pnum,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
// TODO: icmp type/code
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let forwarder handle the packet if it passed route ACLs
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -730,16 +785,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return firewall.ProtocolTCP
|
return firewall.ProtocolTCP, nftypes.TCP
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
return firewall.ProtocolUDP
|
return firewall.ProtocolUDP, nftypes.UDP
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return firewall.ProtocolICMP
|
return firewall.ProtocolICMP, nftypes.ICMP
|
||||||
default:
|
default:
|
||||||
return firewall.ProtocolALL
|
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -767,7 +822,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return m.tcpTracker.IsValidInbound(
|
return m.tcpTracker.IsValidInbound(
|
||||||
@@ -776,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
uint16(d.tcp.SrcPort),
|
uint16(d.tcp.SrcPort),
|
||||||
uint16(d.tcp.DstPort),
|
uint16(d.tcp.DstPort),
|
||||||
getTCPFlags(&d.tcp),
|
getTCPFlags(&d.tcp),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@@ -784,6 +840,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
dstIP,
|
dstIP,
|
||||||
uint16(d.udp.SrcPort),
|
uint16(d.udp.SrcPort),
|
||||||
uint16(d.udp.DstPort),
|
uint16(d.udp.DstPort),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@@ -791,8 +848,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
srcIP,
|
srcIP,
|
||||||
dstIP,
|
dstIP,
|
||||||
d.icmp4.Id,
|
d.icmp4.Id,
|
||||||
d.icmp4.Seq,
|
|
||||||
d.icmp4.TypeCode.Type(),
|
d.icmp4.TypeCode.Type(),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: ICMPv6
|
// TODO: ICMPv6
|
||||||
@@ -812,25 +869,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
|||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
if m.isSpecialICMP(d) {
|
if m.isSpecialICMP(d) {
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default policy: DROP ALL
|
// Default policy: DROP ALL
|
||||||
return true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||||
@@ -850,15 +909,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@@ -868,39 +927,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData), true
|
return rule.mgmtId, rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
|
||||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
|
||||||
|
|
||||||
for _, rule := range m.routeRules {
|
for _, rule := range m.routeRules {
|
||||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||||
return rule.action == firewall.ActionAccept
|
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
@@ -940,36 +996,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
|
||||||
) string {
|
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
|
||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@@ -1017,20 +1069,21 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.forwarder == nil {
|
fwder := m.forwarder.Load()
|
||||||
|
if fwder == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
// don't stop forwarder if in use by netstack
|
// don't stop forwarder if in use by netstack
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
m.forwarder = nil
|
m.forwarder.Store(nil)
|
||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
|||||||
@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: false,
|
stateful: false,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
fw.ActionAccept, "", "allow all")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Add explicit rules matching return traffic pattern
|
// Add explicit rules matching return traffic pattern
|
||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
ip,
|
||||||
|
fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []uint16{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.ActionAccept, "", "explicit return")
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: true,
|
stateful: true,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
_, err := m.AddPeerFiltering(
|
||||||
fw.ActionDrop, "", "default drop")
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionDrop,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@@ -158,7 +169,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -203,7 +214,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut)
|
manager.processOutgoingHooks(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn)
|
manager.dropFilter(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -251,7 +262,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -450,7 +461,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -577,7 +588,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -668,7 +676,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -787,7 +792,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -875,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(
|
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
r.sources,
|
|
||||||
r.dest,
|
|
||||||
r.proto,
|
|
||||||
nil,
|
|
||||||
r.port,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
net.ParseIP(tc.ruleIP),
|
||||||
tc.ruleProto,
|
tc.ruleProto,
|
||||||
tc.ruleSrcPort,
|
tc.ruleSrcPort,
|
||||||
tc.ruleDstPort,
|
tc.ruleDstPort,
|
||||||
tc.ruleAction,
|
tc.ruleAction,
|
||||||
"",
|
"",
|
||||||
tc.name,
|
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -302,12 +302,12 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled)
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter)
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
require.NoError(tb, manager.Close(nil))
|
require.NoError(tb, manager.Close(nil))
|
||||||
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
tc.rule.dest,
|
tc.rule.dest,
|
||||||
tc.rule.proto,
|
tc.rule.proto,
|
||||||
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
var rules []fw.Rule
|
var rules []fw.Rule
|
||||||
for _, r := range tc.rules {
|
for _, r := range tc.rules {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
r.sources,
|
r.sources,
|
||||||
r.dest,
|
r.dest,
|
||||||
r.proto,
|
r.proto,
|
||||||
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for i, p := range tc.packets {
|
for i, p := range tc.packets {
|
||||||
srcIP := net.ParseIP(p.srcIP)
|
srcIP := netip.MustParseAddr(p.srcIP)
|
||||||
dstIP := net.ParseIP(p.dstIP)
|
dstIP := netip.MustParseAddr(p.dstIP)
|
||||||
|
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,9 +20,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
in bool
|
in bool
|
||||||
expDir fw.RuleDirection
|
expDir fw.RuleDirection
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
dPort uint16
|
dPort uint16
|
||||||
hook func([]byte) bool
|
hook func([]byte) bool
|
||||||
expectedID string
|
expectedID string
|
||||||
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Outgoing UDP Packet Hook",
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
in: false,
|
in: false,
|
||||||
expDir: fw.RuleDirectionOUT,
|
expDir: fw.RuleDirectionOUT,
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
dPort: 8000,
|
dPort: 8000,
|
||||||
hook: func([]byte) bool { return true },
|
hook: func([]byte) bool { return true },
|
||||||
},
|
},
|
||||||
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Incoming UDP Packet Hook",
|
name: "Test Incoming UDP Packet Hook",
|
||||||
in: true,
|
in: true,
|
||||||
expDir: fw.RuleDirectionIN,
|
expDir: fw.RuleDirectionIN,
|
||||||
ip: net.IPv6loopback,
|
ip: netip.MustParseAddr("::1"),
|
||||||
dPort: 9000,
|
dPort: 9000,
|
||||||
hook: func([]byte) bool { return false },
|
hook: func([]byte) bool { return false },
|
||||||
},
|
},
|
||||||
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule PeerRule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -246,9 +248,8 @@ func TestManagerReset(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes()) {
|
if m.dropFilter(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -347,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false)
|
manager, err := Create(iface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -401,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
hookCalled := false
|
hookCalled := false
|
||||||
hookID := manager.AddUDPPacketHook(
|
hookID := manager.AddUDPPacketHook(
|
||||||
false,
|
false,
|
||||||
net.ParseIP("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
hookCalled = true
|
hookCalled = true
|
||||||
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes())
|
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes())
|
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -534,8 +534,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up packet parameters
|
// Set up packet parameters
|
||||||
srcIP := net.ParseIP("100.10.0.1")
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
dstIP := net.ParseIP("100.10.0.100")
|
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||||
srcPort := uint16(51334)
|
srcPort := uint16(51334)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
outboundIPv4 := &layers.IPv4{
|
outboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: srcIP,
|
SrcIP: srcIP.AsSlice(),
|
||||||
DstIP: dstIP,
|
DstIP: dstIP.AsSlice(),
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
outboundUDP := &layers.UDP{
|
outboundUDP := &layers.UDP{
|
||||||
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
inboundIPv4 := &layers.IPv4{
|
inboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: dstIP, // Original destination is now source
|
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||||
DstIP: srcIP, // Original source is now destination
|
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
inboundUDP := &layers.UDP{
|
inboundUDP := &layers.UDP{
|
||||||
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes())
|
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -10,16 +11,16 @@ import (
|
|||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// DropOutgoing filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte) bool
|
DropOutgoing(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// DropIncoming filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte) bool
|
DropIncoming(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
// Hook function receives raw network packet data as argument.
|
// Hook function receives raw network packet data as argument.
|
||||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:]) {
|
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package mocks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
net "net"
|
||||||
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(string)
|
ret0, _ := ret[0].(string)
|
||||||
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// DropOutgoing mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
|
|
||||||
dPorts := convertPortInfo(rule.PortInfo)
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("add route rule: %w", err)
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||||
return ruleID, rulesPair, nil
|
return ruleID, rulesPair, nil
|
||||||
}
|
}
|
||||||
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
comment string,
|
|
||||||
) id.RuleID {
|
) id.RuleID {
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||||
if port != nil {
|
if port != nil {
|
||||||
idStr += port.String()
|
idStr += port.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -10,9 +11,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@@ -52,7 +56,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@@ -346,7 +350,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -30,6 +31,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type mocWGIface struct {
|
type mocWGIface struct {
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
}
|
}
|
||||||
@@ -456,7 +459,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
@@ -917,7 +920,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false)
|
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
@@ -189,6 +191,7 @@ type Engine struct {
|
|||||||
persistNetworkMap bool
|
persistNetworkMap bool
|
||||||
latestNetworkMap *mgmProto.NetworkMap
|
latestNetworkMap *mgmProto.NetworkMap
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
|
flowManager nftypes.FlowManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -308,6 +311,12 @@ func (e *Engine) Stop() error {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
|
// stop flow manager after wg interface is gone
|
||||||
|
if e.flowManager != nil {
|
||||||
|
e.flowManager.Close()
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
@@ -342,6 +351,10 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
|
||||||
|
// start flow manager right after interface creation
|
||||||
|
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||||
|
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
if e.config.RosenpassPermissive {
|
if e.config.RosenpassPermissive {
|
||||||
@@ -448,7 +461,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -482,13 +495,13 @@ func (e *Engine) initFirewall() error {
|
|||||||
|
|
||||||
// this rule is static and will be torn down on engine down by the firewall manager
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
if _, err := e.firewall.AddPeerFiltering(
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
firewallManager.ProtocolUDP,
|
firewallManager.ProtocolUDP,
|
||||||
nil,
|
nil,
|
||||||
&port,
|
&port,
|
||||||
firewallManager.ActionAccept,
|
firewallManager.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -512,6 +525,7 @@ func (e *Engine) blockLanAccess() {
|
|||||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
for _, network := range toBlock {
|
for _, network := range toBlock {
|
||||||
if _, err := e.firewall.AddRouteFiltering(
|
if _, err := e.firewall.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{v4},
|
[]netip.Prefix{v4},
|
||||||
network,
|
network,
|
||||||
firewallManager.ProtocolALL,
|
firewallManager.ProtocolALL,
|
||||||
@@ -642,25 +656,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
e.stunTurn.Store(stunTurn)
|
e.stunTurn.Store(stunTurn)
|
||||||
|
|
||||||
relayMsg := wCfg.GetRelay()
|
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||||
if relayMsg != nil {
|
if err != nil {
|
||||||
// when we receive token we expect valid address list too
|
return err
|
||||||
c := &auth.Token{
|
}
|
||||||
Payload: relayMsg.GetTokenPayload(),
|
|
||||||
Signature: relayMsg.GetTokenSignature(),
|
|
||||||
}
|
|
||||||
if err := e.relayManager.UpdateToken(c); err != nil {
|
|
||||||
log.Errorf("failed to update relay token: %v", err)
|
|
||||||
return fmt.Errorf("update relay token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||||
|
if err != nil {
|
||||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
|
||||||
_ = e.relayManager.Serve()
|
|
||||||
} else {
|
|
||||||
e.relayManager.UpdateServerURLs(nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo update signal
|
// todo update signal
|
||||||
@@ -691,6 +694,55 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||||
|
if update != nil {
|
||||||
|
// when we receive token we expect valid address list too
|
||||||
|
c := &auth.Token{
|
||||||
|
Payload: update.GetTokenPayload(),
|
||||||
|
Signature: update.GetTokenSignature(),
|
||||||
|
}
|
||||||
|
if err := e.relayManager.UpdateToken(c); err != nil {
|
||||||
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.relayManager.UpdateServerURLs(update.Urls)
|
||||||
|
|
||||||
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
|
_ = e.relayManager.Serve()
|
||||||
|
} else {
|
||||||
|
e.relayManager.UpdateServerURLs(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flowConfig, err := toFlowLoggerConfig(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return e.flowManager.Update(flowConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||||
|
if config.GetInterval() == nil {
|
||||||
|
return nil, errors.New("flow interval is nil")
|
||||||
|
}
|
||||||
|
return &nftypes.FlowConfig{
|
||||||
|
Enabled: config.GetEnabled(),
|
||||||
|
Counters: config.GetCounters(),
|
||||||
|
URL: config.GetUrl(),
|
||||||
|
TokenPayload: config.GetTokenPayload(),
|
||||||
|
TokenSignature: config.GetTokenSignature(),
|
||||||
|
Interval: config.GetInterval().AsDuration(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||||
// if checks are equal, we skip the update
|
// if checks are equal, we skip the update
|
||||||
|
|||||||
@@ -1435,13 +1435,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
306
client/internal/netflow/conntrack/conntrack.go
Normal file
306
client/internal/netflow/conntrack/conntrack.go
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
nfct "github.com/ti-mo/conntrack"
|
||||||
|
"github.com/ti-mo/netfilter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultChannelSize = 100
|
||||||
|
|
||||||
|
// ConnTrack manages kernel-based conntrack events
|
||||||
|
type ConnTrack struct {
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
iface nftypes.IFaceMapper
|
||||||
|
|
||||||
|
conn *nfct.Conn
|
||||||
|
mux sync.Mutex
|
||||||
|
|
||||||
|
instanceID uuid.UUID
|
||||||
|
started bool
|
||||||
|
done chan struct{}
|
||||||
|
sysctlModified bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||||
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||||
|
return &ConnTrack{
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
iface: iface,
|
||||||
|
instanceID: uuid.New(),
|
||||||
|
started: false,
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||||
|
func (c *ConnTrack) Start(enableCounters bool) error {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.started {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Starting conntrack event listening")
|
||||||
|
|
||||||
|
if enableCounters {
|
||||||
|
c.EnableAccounting()
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := nfct.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dial conntrack: %w", err)
|
||||||
|
}
|
||||||
|
c.conn = conn
|
||||||
|
|
||||||
|
events := make(chan nfct.Event, defaultChannelSize)
|
||||||
|
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
|
||||||
|
netfilter.GroupCTNew,
|
||||||
|
netfilter.GroupCTDestroy,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
c.conn = nil
|
||||||
|
return fmt.Errorf("start conntrack listener: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.started = true
|
||||||
|
|
||||||
|
go c.receiverRoutine(events, errChan)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event := <-events:
|
||||||
|
c.handleEvent(event)
|
||||||
|
case err := <-errChan:
|
||||||
|
log.Errorf("Error from conntrack event listener: %v", err)
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the connection tracking. This method is idempotent.
|
||||||
|
func (c *ConnTrack) Stop() {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if !c.started {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Stopping conntrack event listening")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.done <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.started = false
|
||||||
|
|
||||||
|
c.RestoreAccounting()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops listening for events and cleans up resources
|
||||||
|
func (c *ConnTrack) Close() error {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.started {
|
||||||
|
select {
|
||||||
|
case c.done <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
err := c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
c.started = false
|
||||||
|
|
||||||
|
c.RestoreAccounting()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("close conntrack: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEvent processes incoming conntrack events
|
||||||
|
func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||||
|
if event.Flow == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Type != nfct.EventNew && event.Type != nfct.EventDestroy {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flow := *event.Flow
|
||||||
|
|
||||||
|
proto := nftypes.Protocol(flow.TupleOrig.Proto.Protocol)
|
||||||
|
if proto == nftypes.ProtocolUnknown {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srcIP := flow.TupleOrig.IP.SourceAddress
|
||||||
|
dstIP := flow.TupleOrig.IP.DestinationAddress
|
||||||
|
|
||||||
|
if !c.relevantFlow(srcIP, dstIP) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var srcPort, dstPort uint16
|
||||||
|
var icmpType, icmpCode uint8
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
|
||||||
|
srcPort = flow.TupleOrig.Proto.SourcePort
|
||||||
|
dstPort = flow.TupleOrig.Proto.DestinationPort
|
||||||
|
case nftypes.ICMP:
|
||||||
|
icmpType = flow.TupleOrig.Proto.ICMPType
|
||||||
|
icmpCode = flow.TupleOrig.Proto.ICMPCode
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := c.getFlowID(flow.ID)
|
||||||
|
direction := c.inferDirection(srcIP, dstIP)
|
||||||
|
|
||||||
|
eventType := nftypes.TypeStart
|
||||||
|
eventStr := "New"
|
||||||
|
|
||||||
|
if event.Type == nfct.EventDestroy {
|
||||||
|
eventType = nftypes.TypeEnd
|
||||||
|
eventStr = "Ended"
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
c.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: eventType,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: proto,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
RxPackets: c.mapRxPackets(flow, direction),
|
||||||
|
TxPackets: c.mapTxPackets(flow, direction),
|
||||||
|
RxBytes: c.mapRxBytes(flow, direction),
|
||||||
|
TxBytes: c.mapTxBytes(flow, direction),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// relevantFlow checks if the flow is related to the specified interface
|
||||||
|
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
|
||||||
|
// TODO: filter traffic by interface
|
||||||
|
|
||||||
|
wgnet := c.iface.Address().Network
|
||||||
|
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapRxPackets maps packet counts to RX based on flow direction
|
||||||
|
func (c *ConnTrack) mapRxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersOrig is from external to us (RX)
|
||||||
|
// For Egress: CountersReply is from external to us (RX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersOrig.Packets
|
||||||
|
}
|
||||||
|
return flow.CountersReply.Packets
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapTxPackets maps packet counts to TX based on flow direction
|
||||||
|
func (c *ConnTrack) mapTxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersReply is from us to external (TX)
|
||||||
|
// For Egress: CountersOrig is from us to external (TX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersReply.Packets
|
||||||
|
}
|
||||||
|
return flow.CountersOrig.Packets
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapRxBytes maps byte counts to RX based on flow direction
|
||||||
|
func (c *ConnTrack) mapRxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersOrig is from external to us (RX)
|
||||||
|
// For Egress: CountersReply is from external to us (RX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersOrig.Bytes
|
||||||
|
}
|
||||||
|
return flow.CountersReply.Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapTxBytes maps byte counts to TX based on flow direction
|
||||||
|
func (c *ConnTrack) mapTxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersReply is from us to external (TX)
|
||||||
|
// For Egress: CountersOrig is from us to external (TX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersReply.Bytes
|
||||||
|
}
|
||||||
|
return flow.CountersOrig.Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFlowID creates a unique UUID based on the conntrack ID and instance ID
|
||||||
|
func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
|
||||||
|
var buf [4]byte
|
||||||
|
binary.BigEndian.PutUint32(buf[:], conntrackID)
|
||||||
|
return uuid.NewSHA1(c.instanceID, buf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||||
|
wgaddr := c.iface.Address().IP
|
||||||
|
wgnetwork := c.iface.Address().Network
|
||||||
|
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case wgaddr.Equal(src):
|
||||||
|
return nftypes.Egress
|
||||||
|
case wgaddr.Equal(dst):
|
||||||
|
return nftypes.Ingress
|
||||||
|
case wgnetwork.Contains(src):
|
||||||
|
// netbird network -> resource network
|
||||||
|
return nftypes.Ingress
|
||||||
|
case wgnetwork.Contains(dst):
|
||||||
|
// resource network -> netbird network
|
||||||
|
return nftypes.Egress
|
||||||
|
|
||||||
|
// TODO: handle site2site traffic
|
||||||
|
}
|
||||||
|
|
||||||
|
return nftypes.DirectionUnknown
|
||||||
|
}
|
||||||
9
client/internal/netflow/conntrack/conntrack_nonlinux.go
Normal file
9
client/internal/netflow/conntrack/conntrack_nonlinux.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
|
||||||
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) nftypes.ConnTracker {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
73
client/internal/netflow/conntrack/sysctl.go
Normal file
73
client/internal/netflow/conntrack/sysctl.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// conntrackAcctPath is the sysctl path for conntrack accounting
|
||||||
|
conntrackAcctPath = "net.netfilter.nf_conntrack_acct"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnableAccounting ensures that connection tracking accounting is enabled in the kernel.
|
||||||
|
func (c *ConnTrack) EnableAccounting() {
|
||||||
|
// haven't restored yet
|
||||||
|
if c.sysctlModified {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modified, err := setSysctl(conntrackAcctPath, 1)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to enable conntrack accounting: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.sysctlModified = modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreAccounting restores the connection tracking accounting setting to its original value.
|
||||||
|
func (c *ConnTrack) RestoreAccounting() {
|
||||||
|
if !c.sysctlModified {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := setSysctl(conntrackAcctPath, 0); err != nil {
|
||||||
|
log.Warnf("Failed to restore conntrack accounting: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sysctlModified = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSysctl sets a sysctl configuration and returns whether it was modified.
|
||||||
|
func setSysctl(key string, desiredValue int) (bool, error) {
|
||||||
|
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||||
|
|
||||||
|
currentValue, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||||
|
if err != nil && len(currentValue) > 0 {
|
||||||
|
return false, fmt.Errorf("convert current value to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentV == desiredValue {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:gosec
|
||||||
|
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||||
|
return false, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
137
client/internal/netflow/logger/logger.go
Normal file
137
client/internal/netflow/logger/logger.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/store"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rcvChan chan *types.EventFields
|
||||||
|
type Logger struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
enabled atomic.Bool
|
||||||
|
rcvChan atomic.Pointer[rcvChan]
|
||||||
|
cancelReceiver context.CancelFunc
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
wgIfaceIPNet net.IPNet
|
||||||
|
Store types.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
return &Logger{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
wgIfaceIPNet: wgIfaceIPNet,
|
||||||
|
Store: store.NewMemoryStore(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) StoreEvent(flowEvent types.EventFields) {
|
||||||
|
if !l.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := l.rcvChan.Load()
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case *c <- &flowEvent:
|
||||||
|
default:
|
||||||
|
// todo: we should collect or log on this
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Enable() {
|
||||||
|
go l.startReceiver()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) startReceiver() {
|
||||||
|
if l.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mux.Lock()
|
||||||
|
ctx, cancel := context.WithCancel(l.ctx)
|
||||||
|
l.cancelReceiver = cancel
|
||||||
|
l.mux.Unlock()
|
||||||
|
|
||||||
|
c := make(rcvChan, 100)
|
||||||
|
l.rcvChan.Store(&c)
|
||||||
|
l.enabled.Store(true)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("flow Memory store receiver stopped")
|
||||||
|
return
|
||||||
|
case eventFields := <-c:
|
||||||
|
id := uuid.New()
|
||||||
|
event := types.Event{
|
||||||
|
ID: id,
|
||||||
|
EventFields: *eventFields,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Direction == types.Ingress {
|
||||||
|
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||||
|
event.SourceResourceID = []byte(l.statusRecorder.CheckRoutes(event.SourceIP))
|
||||||
|
}
|
||||||
|
} else if event.Direction == types.Egress {
|
||||||
|
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||||
|
event.DestResourceID = []byte(l.statusRecorder.CheckRoutes(event.DestIP))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Store.StoreEvent(&event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Disable() {
|
||||||
|
l.stop()
|
||||||
|
l.Store.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) stop() {
|
||||||
|
if !l.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.enabled.Store(false)
|
||||||
|
l.mux.Lock()
|
||||||
|
if l.cancelReceiver != nil {
|
||||||
|
l.cancelReceiver()
|
||||||
|
l.cancelReceiver = nil
|
||||||
|
}
|
||||||
|
l.rcvChan.Store(nil)
|
||||||
|
l.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) GetEvents() []*types.Event {
|
||||||
|
return l.Store.GetEvents()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
|
||||||
|
l.Store.DeleteEvents(ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Close() {
|
||||||
|
l.stop()
|
||||||
|
l.cancel()
|
||||||
|
}
|
||||||
68
client/internal/netflow/logger/logger_test.go
Normal file
68
client/internal/netflow/logger/logger_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package logger_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/logger"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStore(t *testing.T) {
|
||||||
|
logger := logger.New(context.Background(), nil, net.IPNet{})
|
||||||
|
logger.Enable()
|
||||||
|
|
||||||
|
event := types.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: types.TypeStart,
|
||||||
|
Direction: types.Ingress,
|
||||||
|
Protocol: 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
wait := func() { time.Sleep(time.Millisecond) }
|
||||||
|
wait()
|
||||||
|
logger.StoreEvent(event)
|
||||||
|
wait()
|
||||||
|
|
||||||
|
allEvents := logger.GetEvents()
|
||||||
|
matched := false
|
||||||
|
for _, e := range allEvents {
|
||||||
|
if e.EventFields.FlowID == event.FlowID {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Errorf("didn't match any event")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test disable
|
||||||
|
logger.Disable()
|
||||||
|
wait()
|
||||||
|
logger.StoreEvent(event)
|
||||||
|
wait()
|
||||||
|
allEvents = logger.GetEvents()
|
||||||
|
if len(allEvents) != 0 {
|
||||||
|
t.Errorf("expected 0 events, got %d", len(allEvents))
|
||||||
|
}
|
||||||
|
|
||||||
|
// test re-enable
|
||||||
|
logger.Enable()
|
||||||
|
wait()
|
||||||
|
logger.StoreEvent(event)
|
||||||
|
wait()
|
||||||
|
|
||||||
|
allEvents = logger.GetEvents()
|
||||||
|
matched = false
|
||||||
|
for _, e := range allEvents {
|
||||||
|
if e.EventFields.FlowID == event.FlowID {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Errorf("didn't match any event")
|
||||||
|
}
|
||||||
|
}
|
||||||
240
client/internal/netflow/manager.go
Normal file
240
client/internal/netflow/manager.go
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
package netflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/logger"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/flow/client"
|
||||||
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager handles netflow tracking and logging
|
||||||
|
type Manager struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
logger nftypes.FlowLogger
|
||||||
|
flowConfig *nftypes.FlowConfig
|
||||||
|
conntrack nftypes.ConnTracker
|
||||||
|
ctx context.Context
|
||||||
|
receiverClient *client.GRPCClient
|
||||||
|
publicKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new netflow manager
|
||||||
|
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||||
|
var ipNet net.IPNet
|
||||||
|
if iface != nil {
|
||||||
|
ipNet = *iface.Address().Network
|
||||||
|
}
|
||||||
|
flowLogger := logger.New(ctx, statusRecorder, ipNet)
|
||||||
|
|
||||||
|
var ct nftypes.ConnTracker
|
||||||
|
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||||
|
ct = conntrack.New(flowLogger, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Manager{
|
||||||
|
logger: flowLogger,
|
||||||
|
conntrack: ct,
|
||||||
|
ctx: ctx,
|
||||||
|
publicKey: publicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update applies new flow configuration settings
|
||||||
|
// needsNewClient checks if a new client needs to be created
|
||||||
|
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||||
|
current := m.flowConfig
|
||||||
|
return previous == nil ||
|
||||||
|
!previous.Enabled ||
|
||||||
|
previous.TokenPayload != current.TokenPayload ||
|
||||||
|
previous.TokenSignature != current.TokenSignature ||
|
||||||
|
previous.URL != current.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// enableFlow starts components for flow tracking
|
||||||
|
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||||
|
// first make sender ready so events don't pile up
|
||||||
|
if m.needsNewClient(previous) {
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
if err := m.receiverClient.Close(); err != nil {
|
||||||
|
log.Warnf("error closing previous flow client: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create client: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||||
|
|
||||||
|
m.receiverClient = flowClient
|
||||||
|
go m.receiveACKs(flowClient)
|
||||||
|
go m.startSender()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Enable()
|
||||||
|
|
||||||
|
if m.conntrack != nil {
|
||||||
|
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
|
||||||
|
return fmt.Errorf("start conntrack: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// disableFlow stops components for flow tracking
|
||||||
|
func (m *Manager) disableFlow() error {
|
||||||
|
if m.conntrack != nil {
|
||||||
|
m.conntrack.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Disable()
|
||||||
|
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
return m.receiverClient.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update applies new flow configuration settings
|
||||||
|
func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||||
|
if update == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
previous := m.flowConfig
|
||||||
|
m.flowConfig = update
|
||||||
|
|
||||||
|
if update.Enabled {
|
||||||
|
return m.enableFlow(previous)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.disableFlow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up all resources
|
||||||
|
func (m *Manager) Close() {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
if m.conntrack != nil {
|
||||||
|
m.conntrack.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
if err := m.receiverClient.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close receiver client: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogger returns the flow logger
|
||||||
|
func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||||
|
return m.logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) startSender() {
|
||||||
|
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
events := m.logger.GetEvents()
|
||||||
|
for _, event := range events {
|
||||||
|
if err := m.send(event); err != nil {
|
||||||
|
log.Errorf("failed to send flow event to server: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Tracef("sent flow event: %s", event.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||||
|
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||||
|
log.Tracef("received flow event ack: %s", ack.EventId)
|
||||||
|
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("failed to receive flow event ack: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) send(event *nftypes.Event) error {
|
||||||
|
m.mux.Lock()
|
||||||
|
client := m.receiverClient
|
||||||
|
m.mux.Unlock()
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return client.Send(toProtoEvent(m.publicKey, event))
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
|
||||||
|
protoEvent := &proto.FlowEvent{
|
||||||
|
EventId: event.ID[:],
|
||||||
|
Timestamp: timestamppb.New(event.Timestamp),
|
||||||
|
PublicKey: publicKey,
|
||||||
|
FlowFields: &proto.FlowFields{
|
||||||
|
FlowId: event.FlowID[:],
|
||||||
|
RuleId: event.RuleID,
|
||||||
|
Type: proto.Type(event.Type),
|
||||||
|
Direction: proto.Direction(event.Direction),
|
||||||
|
Protocol: uint32(event.Protocol),
|
||||||
|
SourceIp: event.SourceIP.AsSlice(),
|
||||||
|
DestIp: event.DestIP.AsSlice(),
|
||||||
|
RxPackets: event.RxPackets,
|
||||||
|
TxPackets: event.TxPackets,
|
||||||
|
RxBytes: event.RxBytes,
|
||||||
|
TxBytes: event.TxBytes,
|
||||||
|
SourceResourceId: event.SourceResourceID,
|
||||||
|
DestResourceId: event.DestResourceID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Protocol == nftypes.ICMP {
|
||||||
|
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
|
||||||
|
IcmpInfo: &proto.ICMPInfo{
|
||||||
|
IcmpType: uint32(event.ICMPType),
|
||||||
|
IcmpCode: uint32(event.ICMPCode),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return protoEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_PortInfo{
|
||||||
|
PortInfo: &proto.PortInfo{
|
||||||
|
SourcePort: uint32(event.SourcePort),
|
||||||
|
DestPort: uint32(event.DestPort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return protoEvent
|
||||||
|
}
|
||||||
52
client/internal/netflow/store/memory.go
Normal file
52
client/internal/netflow/store/memory.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewMemoryStore() *Memory {
|
||||||
|
return &Memory{
|
||||||
|
events: make(map[uuid.UUID]*types.Event),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Memory struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
events map[uuid.UUID]*types.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Memory) StoreEvent(event *types.Event) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.events[event.ID] = event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Memory) Close() {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
maps.Clear(m.events)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Memory) GetEvents() []*types.Event {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
events := make([]*types.Event, 0, len(m.events))
|
||||||
|
for _, event := range m.events {
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Memory) DeleteEvents(ids []uuid.UUID) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
for _, id := range ids {
|
||||||
|
delete(m.events, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
151
client/internal/netflow/types/types.go
Normal file
151
client/internal/netflow/types/types.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Protocol uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolUnknown = Protocol(0)
|
||||||
|
ICMP = Protocol(1)
|
||||||
|
TCP = Protocol(6)
|
||||||
|
UDP = Protocol(17)
|
||||||
|
SCTP = Protocol(132)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p Protocol) String() string {
|
||||||
|
switch p {
|
||||||
|
case 1:
|
||||||
|
return "ICMP"
|
||||||
|
case 6:
|
||||||
|
return "TCP"
|
||||||
|
case 17:
|
||||||
|
return "UDP"
|
||||||
|
case 132:
|
||||||
|
return "SCTP"
|
||||||
|
default:
|
||||||
|
return strconv.FormatUint(uint64(p), 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Type int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeUnknown = Type(iota)
|
||||||
|
TypeStart
|
||||||
|
TypeEnd
|
||||||
|
TypeDrop
|
||||||
|
)
|
||||||
|
|
||||||
|
type Direction int
|
||||||
|
|
||||||
|
func (d Direction) String() string {
|
||||||
|
switch d {
|
||||||
|
case Ingress:
|
||||||
|
return "ingress"
|
||||||
|
case Egress:
|
||||||
|
return "egress"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
DirectionUnknown = Direction(iota)
|
||||||
|
Ingress
|
||||||
|
Egress
|
||||||
|
)
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Timestamp time.Time
|
||||||
|
EventFields
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventFields struct {
|
||||||
|
FlowID uuid.UUID
|
||||||
|
Type Type
|
||||||
|
RuleID []byte
|
||||||
|
Direction Direction
|
||||||
|
Protocol Protocol
|
||||||
|
SourceIP netip.Addr
|
||||||
|
DestIP netip.Addr
|
||||||
|
SourceResourceID []byte
|
||||||
|
DestResourceID []byte
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
|
ICMPType uint8
|
||||||
|
ICMPCode uint8
|
||||||
|
RxPackets uint64
|
||||||
|
TxPackets uint64
|
||||||
|
RxBytes uint64
|
||||||
|
TxBytes uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowConfig struct {
|
||||||
|
URL string
|
||||||
|
Interval time.Duration
|
||||||
|
Enabled bool
|
||||||
|
Counters bool
|
||||||
|
TokenPayload string
|
||||||
|
TokenSignature string
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowManager interface {
|
||||||
|
// FlowConfig handles network map updates
|
||||||
|
Update(update *FlowConfig) error
|
||||||
|
// Close closes the manager
|
||||||
|
Close()
|
||||||
|
// GetLogger returns a flow logger
|
||||||
|
GetLogger() FlowLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowLogger interface {
|
||||||
|
// StoreEvent stores a flow event
|
||||||
|
StoreEvent(flowEvent EventFields)
|
||||||
|
// GetEvents returns all stored events
|
||||||
|
GetEvents() []*Event
|
||||||
|
// DeleteEvents deletes events from the store
|
||||||
|
DeleteEvents([]uuid.UUID)
|
||||||
|
// Close closes the logger
|
||||||
|
Close()
|
||||||
|
// Enable enables the flow logger receiver
|
||||||
|
Enable()
|
||||||
|
// Disable disables the flow logger receiver
|
||||||
|
Disable()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store interface {
|
||||||
|
// StoreEvent stores a flow event
|
||||||
|
StoreEvent(event *Event)
|
||||||
|
// GetEvents returns all stored events
|
||||||
|
GetEvents() []*Event
|
||||||
|
// DeleteEvents deletes events from the store
|
||||||
|
DeleteEvents([]uuid.UUID)
|
||||||
|
// Close closes the store
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnTracker defines the interface for connection tracking functionality
|
||||||
|
type ConnTracker interface {
|
||||||
|
// Start begins tracking connections by listening for conntrack events.
|
||||||
|
Start(bool) error
|
||||||
|
// Stop stops the connection tracking.
|
||||||
|
Stop()
|
||||||
|
// Close stops listening for events and cleans up resources
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// IFaceMapper provides interface to check if we're using userspace WireGuard
|
||||||
|
type IFaceMapper interface {
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
Name() string
|
||||||
|
Address() wgaddr.Address
|
||||||
|
}
|
||||||
73
client/internal/peer/route.go
Normal file
73
client/internal/peer/route.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type routeIDLookup struct {
|
||||||
|
localMap sync.Map
|
||||||
|
remoteMap sync.Map
|
||||||
|
resolvedIPs sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
|
||||||
|
_, exists := r.localMap.LoadOrStore(route, resourceID)
|
||||||
|
if exists {
|
||||||
|
log.Tracef("resourceID %s already exists in local map", resourceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
|
||||||
|
r.localMap.Delete(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
|
||||||
|
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
|
||||||
|
if exists {
|
||||||
|
log.Tracef("resourceID %s already exists in remote map", resourceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
|
||||||
|
r.remoteMap.Delete(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
|
||||||
|
r.resolvedIPs.Store(route.Addr(), resourceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
|
||||||
|
r.resolvedIPs.Delete(route.Addr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeIDLookup) Lookup(ip netip.Addr) string {
|
||||||
|
resId, ok := r.resolvedIPs.Load(ip)
|
||||||
|
if ok {
|
||||||
|
return resId.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourceID string
|
||||||
|
r.localMap.Range(func(key, value interface{}) bool {
|
||||||
|
if key.(netip.Prefix).Contains(ip) {
|
||||||
|
resourceID = value.(string)
|
||||||
|
return false
|
||||||
|
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if resourceID == "" {
|
||||||
|
r.remoteMap.Range(func(key, value interface{}) bool {
|
||||||
|
if key.(netip.Prefix).Contains(ip) {
|
||||||
|
resourceID = value.(string)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return resourceID
|
||||||
|
}
|
||||||
@@ -176,6 +176,8 @@ type Status struct {
|
|||||||
eventQueue *EventQueue
|
eventQueue *EventQueue
|
||||||
|
|
||||||
ingressGwMgr *ingressgw.Manager
|
ingressGwMgr *ingressgw.Manager
|
||||||
|
|
||||||
|
routeIDLookup routeIDLookup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
// NewRecorder returns a new Status instance
|
||||||
@@ -311,7 +313,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
@@ -323,6 +325,14 @@ func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
|||||||
peerState.AddRoute(route)
|
peerState.AddRoute(route)
|
||||||
d.peers[peer] = peerState
|
d.peers[peer] = peerState
|
||||||
|
|
||||||
|
pref, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||||
|
} else {
|
||||||
|
|
||||||
|
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
|
||||||
|
}
|
||||||
|
|
||||||
// todo: consider to make sense of this notification or not
|
// todo: consider to make sense of this notification or not
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
@@ -340,11 +350,27 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
|||||||
peerState.DeleteRoute(route)
|
peerState.DeleteRoute(route)
|
||||||
d.peers[peer] = peerState
|
d.peers[peer] = peerState
|
||||||
|
|
||||||
|
pref, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||||
|
} else {
|
||||||
|
d.routeIDLookup.RemoveRemoteRouteID(pref)
|
||||||
|
}
|
||||||
|
|
||||||
// todo: consider to make sense of this notification or not
|
// todo: consider to make sense of this notification or not
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckRoutes checks if the source and destination addresses are within the same route
|
||||||
|
// and returns the resource ID of the route that contains the addresses
|
||||||
|
func (d *Status) CheckRoutes(ip netip.Addr) (resId string) {
|
||||||
|
if d == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return d.routeIDLookup.Lookup(ip)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
@@ -558,6 +584,50 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
|||||||
d.notifyAddressChanged()
|
d.notifyAddressChanged()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||||
|
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
pref, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.localPeer.Routes == nil {
|
||||||
|
d.localPeer.Routes = map[string]struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.localPeer.Routes[route] = struct{}{}
|
||||||
|
|
||||||
|
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||||
|
func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
pref, err := netip.ParsePrefix(route)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(d.localPeer.Routes, route)
|
||||||
|
|
||||||
|
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||||
|
func (d *Status) CleanLocalPeerStateRoutes() {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
d.localPeer.Routes = map[string]struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
// CleanLocalPeerState cleans local peer status
|
// CleanLocalPeerState cleans local peer status
|
||||||
func (d *Status) CleanLocalPeerState() {
|
func (d *Status) CleanLocalPeerState() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -641,7 +711,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
|||||||
d.nsGroupStates = dnsStates
|
d.nsGroupStates = dnsStates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
|
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
@@ -650,6 +720,10 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
|
|||||||
Prefixes: prefixes,
|
Prefixes: prefixes,
|
||||||
ParentDomain: originalDomain,
|
ParentDomain: originalDomain,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||||
@@ -660,6 +734,10 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
|||||||
for k, v := range d.resolvedDomainsStates {
|
for k, v := range d.resolvedDomainsStates {
|
||||||
if v.ParentDomain == domain {
|
if v.ParentDomain == domain {
|
||||||
delete(d.resolvedDomainsStates, k)
|
delete(d.resolvedDomainsStates, k)
|
||||||
|
|
||||||
|
for _, prefix := range v.Prefixes {
|
||||||
|
d.routeIDLookup.RemoveResolvedIP(prefix)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ type PKCEAuthProviderConfig struct {
|
|||||||
RedirectURLs []string
|
RedirectURLs []string
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
UseIDToken bool
|
UseIDToken bool
|
||||||
//ClientCertPair is used for mTLS authentication to the IDP
|
// ClientCertPair is used for mTLS authentication to the IDP
|
||||||
ClientCertPair *tls.Certificate
|
ClientCertPair *tls.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
|
|||||||
c.connectEvent()
|
c.connectEvent()
|
||||||
}
|
}
|
||||||
|
|
||||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add peer state route: %w", err)
|
return fmt.Errorf("add peer state route: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -321,7 +321,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
|
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||||
|
|
||||||
if len(toAdd) > 0 {
|
if len(toAdd) > 0 {
|
||||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
|||||||
@@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
|
|||||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||||
r.dynamicDomains[domain] = updatedPrefixes
|
r.dynamicDomains[domain] = updatedPrefixes
|
||||||
|
|
||||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
|
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes, r.route.GetResourceID())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
|||||||
@@ -103,9 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
|||||||
|
|
||||||
delete(m.routes, route.ID)
|
delete(m.routes, route.ID)
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
m.statusRecorder.RemoveLocalPeerStateRoute(route.Network.String())
|
||||||
delete(state.Routes, route.Network.String())
|
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -131,18 +129,12 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
|||||||
|
|
||||||
m.routes[route.ID] = route
|
m.routes[route.ID] = route
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
|
||||||
if state.Routes == nil {
|
|
||||||
state.Routes = map[string]struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
routeStr := route.Network.String()
|
routeStr := route.Network.String()
|
||||||
if route.IsDynamic() {
|
if route.IsDynamic() {
|
||||||
routeStr = route.Domains.SafeString()
|
routeStr = route.Domains.SafeString()
|
||||||
}
|
}
|
||||||
state.Routes[routeStr] = struct{}{}
|
|
||||||
|
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -164,9 +156,7 @@ func (m *serverRouter) cleanUp() {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
m.statusRecorder.CleanLocalPeerStateRoutes()
|
||||||
state.Routes = nil
|
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -129,13 +128,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
@@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
srcIP = engine.GetWgAddr()
|
srcIP = engine.GetWgAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid source IP address")
|
||||||
|
}
|
||||||
|
|
||||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||||
if req.GetDestinationIp() == "self" {
|
if req.GetDestinationIp() == "self" {
|
||||||
dstIP = engine.GetWgAddr()
|
dstIP = engine.GetWgAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid source IP address")
|
||||||
|
}
|
||||||
|
|
||||||
if srcIP == nil || dstIP == nil {
|
if srcIP == nil || dstIP == nil {
|
||||||
return nil, fmt.Errorf("invalid IP address")
|
return nil, fmt.Errorf("invalid IP address")
|
||||||
}
|
}
|
||||||
@@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
builder := &uspfilter.PacketBuilder{
|
builder := &uspfilter.PacketBuilder{
|
||||||
SrcIP: srcIP,
|
SrcIP: srcAddr,
|
||||||
DstIP: dstIP,
|
DstIP: dstAddr,
|
||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
SrcPort: uint16(req.GetSourcePort()),
|
SrcPort: uint16(req.GetSourcePort()),
|
||||||
DstPort: uint16(req.GetDestinationPort()),
|
DstPort: uint16(req.GetDestinationPort()),
|
||||||
|
|||||||
32
flow/client/auth.go
Normal file
32
flow/client/auth.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ credentials.PerRPCCredentials = (*authToken)(nil)
|
||||||
|
|
||||||
|
type authToken struct {
|
||||||
|
metaMap map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t authToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
|
||||||
|
return t.metaMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (authToken) RequireTransportSecurity() bool {
|
||||||
|
return false // Set to true if you want to require a secure connection
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAuthToken returns a DialOption which sets the receiver flow credentials and places auth state on each outbound RPC
|
||||||
|
func withAuthToken(payload, signature string) grpc.DialOption {
|
||||||
|
value := fmt.Sprintf("%s.%s", signature, payload)
|
||||||
|
authMap := map[string]string{
|
||||||
|
"authorization": "Bearer " + value,
|
||||||
|
}
|
||||||
|
return grpc.WithPerRPCCredentials(authToken{metaMap: authMap})
|
||||||
|
}
|
||||||
167
flow/client/client.go
Normal file
167
flow/client/client.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GRPCClient struct {
|
||||||
|
realClient proto.FlowServiceClient
|
||||||
|
clientConn *grpc.ClientConn
|
||||||
|
stream proto.FlowService_EventsClient
|
||||||
|
streamMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
|
||||||
|
var opts []grpc.DialOption
|
||||||
|
|
||||||
|
if strings.Contains(addr, "443") {
|
||||||
|
certPool, err := x509.SystemCertPool()
|
||||||
|
if err != nil || certPool == nil {
|
||||||
|
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||||
|
certPool = embeddedroots.Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||||
|
RootCAs: certPool,
|
||||||
|
})))
|
||||||
|
} else {
|
||||||
|
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
}
|
||||||
|
|
||||||
|
opts = append(opts,
|
||||||
|
nbgrpc.WithCustomDialer(),
|
||||||
|
grpc.WithIdleTimeout(interval*2),
|
||||||
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
|
Time: 30 * time.Second,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}),
|
||||||
|
withAuthToken(payload, signature),
|
||||||
|
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
|
||||||
|
)
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(addr, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating new grpc client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GRPCClient{
|
||||||
|
realClient: proto.NewFlowServiceClient(conn),
|
||||||
|
clientConn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) Close() error {
|
||||||
|
c.streamMu.Lock()
|
||||||
|
defer c.streamMu.Unlock()
|
||||||
|
|
||||||
|
c.stream = nil
|
||||||
|
return c.clientConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||||
|
backOff := defaultBackoff(ctx, interval)
|
||||||
|
operation := func() error {
|
||||||
|
return c.establishStreamAndReceive(ctx, msgHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := backoff.Retry(operation, backOff); err != nil {
|
||||||
|
return fmt.Errorf("receive failed permanently: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||||
|
if c.clientConn.GetState() == connectivity.Shutdown {
|
||||||
|
return backoff.Permanent(errors.New("connection to flow receiver has been shut down"))
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create event stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = checkHeader(stream); err != nil {
|
||||||
|
return fmt.Errorf("check header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.streamMu.Lock()
|
||||||
|
c.stream = stream
|
||||||
|
c.streamMu.Unlock()
|
||||||
|
|
||||||
|
return c.receive(stream, msgHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||||
|
for {
|
||||||
|
msg, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("receive from stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := msgHandler(msg); err != nil {
|
||||||
|
return fmt.Errorf("handle message: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkHeader(stream proto.FlowService_EventsClient) error {
|
||||||
|
header, err := stream.Header()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("waiting for flow receiver header: %s", err)
|
||||||
|
return fmt.Errorf("wait for header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(header) == 0 {
|
||||||
|
log.Error("flow receiver sent no headers")
|
||||||
|
return fmt.Errorf("should have headers")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
|
||||||
|
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 800 * time.Millisecond,
|
||||||
|
RandomizationFactor: 1,
|
||||||
|
Multiplier: 1.7,
|
||||||
|
MaxInterval: interval / 2,
|
||||||
|
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||||
|
c.streamMu.Lock()
|
||||||
|
stream := c.stream
|
||||||
|
c.streamMu.Unlock()
|
||||||
|
|
||||||
|
if stream == nil {
|
||||||
|
return errors.New("stream not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stream.Send(event); err != nil {
|
||||||
|
return fmt.Errorf("send flow event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
769
flow/proto/flow.pb.go
Normal file
769
flow/proto/flow.pb.go
Normal file
@@ -0,0 +1,769 @@
|
|||||||
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// protoc-gen-go v1.26.0
|
||||||
|
// protoc v4.24.3
|
||||||
|
// source: flow.proto
|
||||||
|
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||||
|
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
reflect "reflect"
|
||||||
|
sync "sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Verify that this generated code is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||||
|
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||||
|
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Flow event types
|
||||||
|
type Type int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
Type_TYPE_UNKNOWN Type = 0
|
||||||
|
Type_TYPE_START Type = 1
|
||||||
|
Type_TYPE_END Type = 2
|
||||||
|
Type_TYPE_DROP Type = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Enum value maps for Type.
|
||||||
|
var (
|
||||||
|
Type_name = map[int32]string{
|
||||||
|
0: "TYPE_UNKNOWN",
|
||||||
|
1: "TYPE_START",
|
||||||
|
2: "TYPE_END",
|
||||||
|
3: "TYPE_DROP",
|
||||||
|
}
|
||||||
|
Type_value = map[string]int32{
|
||||||
|
"TYPE_UNKNOWN": 0,
|
||||||
|
"TYPE_START": 1,
|
||||||
|
"TYPE_END": 2,
|
||||||
|
"TYPE_DROP": 3,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (x Type) Enum() *Type {
|
||||||
|
p := new(Type)
|
||||||
|
*p = x
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x Type) String() string {
|
||||||
|
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Type) Descriptor() protoreflect.EnumDescriptor {
|
||||||
|
return file_flow_proto_enumTypes[0].Descriptor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Type) Type() protoreflect.EnumType {
|
||||||
|
return &file_flow_proto_enumTypes[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x Type) Number() protoreflect.EnumNumber {
|
||||||
|
return protoreflect.EnumNumber(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use Type.Descriptor instead.
|
||||||
|
func (Type) EnumDescriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flow direction
|
||||||
|
type Direction int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
Direction_DIRECTION_UNKNOWN Direction = 0
|
||||||
|
Direction_INGRESS Direction = 1
|
||||||
|
Direction_EGRESS Direction = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// Enum value maps for Direction.
|
||||||
|
var (
|
||||||
|
Direction_name = map[int32]string{
|
||||||
|
0: "DIRECTION_UNKNOWN",
|
||||||
|
1: "INGRESS",
|
||||||
|
2: "EGRESS",
|
||||||
|
}
|
||||||
|
Direction_value = map[string]int32{
|
||||||
|
"DIRECTION_UNKNOWN": 0,
|
||||||
|
"INGRESS": 1,
|
||||||
|
"EGRESS": 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (x Direction) Enum() *Direction {
|
||||||
|
p := new(Direction)
|
||||||
|
*p = x
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x Direction) String() string {
|
||||||
|
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Direction) Descriptor() protoreflect.EnumDescriptor {
|
||||||
|
return file_flow_proto_enumTypes[1].Descriptor()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Direction) Type() protoreflect.EnumType {
|
||||||
|
return &file_flow_proto_enumTypes[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x Direction) Number() protoreflect.EnumNumber {
|
||||||
|
return protoreflect.EnumNumber(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use Direction.Descriptor instead.
|
||||||
|
func (Direction) EnumDescriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowEvent struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
// Unique client event identifier
|
||||||
|
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||||
|
// When the event occurred
|
||||||
|
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||||
|
// Public key of the sending peer
|
||||||
|
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
|
||||||
|
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEvent) Reset() {
|
||||||
|
*x = FlowEvent{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_flow_proto_msgTypes[0]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEvent) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FlowEvent) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *FlowEvent) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_flow_proto_msgTypes[0]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use FlowEvent.ProtoReflect.Descriptor instead.
|
||||||
|
func (*FlowEvent) Descriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEvent) GetEventId() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.EventId
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEvent) GetTimestamp() *timestamppb.Timestamp {
|
||||||
|
if x != nil {
|
||||||
|
return x.Timestamp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEvent) GetPublicKey() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.PublicKey
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEvent) GetFlowFields() *FlowFields {
|
||||||
|
if x != nil {
|
||||||
|
return x.FlowFields
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowEventAck struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
// Unique client event identifier that has been ack'ed
|
||||||
|
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEventAck) Reset() {
|
||||||
|
*x = FlowEventAck{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_flow_proto_msgTypes[1]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEventAck) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FlowEventAck) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *FlowEventAck) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_flow_proto_msgTypes[1]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use FlowEventAck.ProtoReflect.Descriptor instead.
|
||||||
|
func (*FlowEventAck) Descriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowEventAck) GetEventId() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.EventId
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowFields struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
// Unique client flow session identifier
|
||||||
|
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
|
||||||
|
// Flow type
|
||||||
|
Type Type `protobuf:"varint,2,opt,name=type,proto3,enum=flow.Type" json:"type,omitempty"`
|
||||||
|
// RuleId identifies the rule that allowed or denied the connection
|
||||||
|
RuleId []byte `protobuf:"bytes,3,opt,name=rule_id,json=ruleId,proto3" json:"rule_id,omitempty"`
|
||||||
|
// Initiating traffic direction
|
||||||
|
Direction Direction `protobuf:"varint,4,opt,name=direction,proto3,enum=flow.Direction" json:"direction,omitempty"`
|
||||||
|
// IP protocol number
|
||||||
|
Protocol uint32 `protobuf:"varint,5,opt,name=protocol,proto3" json:"protocol,omitempty"`
|
||||||
|
// Source IP address
|
||||||
|
SourceIp []byte `protobuf:"bytes,6,opt,name=source_ip,json=sourceIp,proto3" json:"source_ip,omitempty"`
|
||||||
|
// Destination IP address
|
||||||
|
DestIp []byte `protobuf:"bytes,7,opt,name=dest_ip,json=destIp,proto3" json:"dest_ip,omitempty"`
|
||||||
|
// Layer 4 -specific information
|
||||||
|
//
|
||||||
|
// Types that are assignable to ConnectionInfo:
|
||||||
|
//
|
||||||
|
// *FlowFields_PortInfo
|
||||||
|
// *FlowFields_IcmpInfo
|
||||||
|
ConnectionInfo isFlowFields_ConnectionInfo `protobuf_oneof:"connection_info"`
|
||||||
|
// Number of packets
|
||||||
|
RxPackets uint64 `protobuf:"varint,10,opt,name=rx_packets,json=rxPackets,proto3" json:"rx_packets,omitempty"`
|
||||||
|
TxPackets uint64 `protobuf:"varint,11,opt,name=tx_packets,json=txPackets,proto3" json:"tx_packets,omitempty"`
|
||||||
|
// Number of bytes
|
||||||
|
RxBytes uint64 `protobuf:"varint,12,opt,name=rx_bytes,json=rxBytes,proto3" json:"rx_bytes,omitempty"`
|
||||||
|
TxBytes uint64 `protobuf:"varint,13,opt,name=tx_bytes,json=txBytes,proto3" json:"tx_bytes,omitempty"`
|
||||||
|
// Resource ID
|
||||||
|
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
|
||||||
|
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) Reset() {
|
||||||
|
*x = FlowFields{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_flow_proto_msgTypes[2]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FlowFields) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *FlowFields) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_flow_proto_msgTypes[2]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use FlowFields.ProtoReflect.Descriptor instead.
|
||||||
|
func (*FlowFields) Descriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{2}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetFlowId() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.FlowId
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetType() Type {
|
||||||
|
if x != nil {
|
||||||
|
return x.Type
|
||||||
|
}
|
||||||
|
return Type_TYPE_UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetRuleId() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.RuleId
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetDirection() Direction {
|
||||||
|
if x != nil {
|
||||||
|
return x.Direction
|
||||||
|
}
|
||||||
|
return Direction_DIRECTION_UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetProtocol() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.Protocol
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetSourceIp() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.SourceIp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetDestIp() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.DestIp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
|
||||||
|
if m != nil {
|
||||||
|
return m.ConnectionInfo
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetPortInfo() *PortInfo {
|
||||||
|
if x, ok := x.GetConnectionInfo().(*FlowFields_PortInfo); ok {
|
||||||
|
return x.PortInfo
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetIcmpInfo() *ICMPInfo {
|
||||||
|
if x, ok := x.GetConnectionInfo().(*FlowFields_IcmpInfo); ok {
|
||||||
|
return x.IcmpInfo
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetRxPackets() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.RxPackets
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetTxPackets() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.TxPackets
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetRxBytes() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.RxBytes
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetTxBytes() uint64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.TxBytes
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetSourceResourceId() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.SourceResourceId
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *FlowFields) GetDestResourceId() []byte {
|
||||||
|
if x != nil {
|
||||||
|
return x.DestResourceId
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type isFlowFields_ConnectionInfo interface {
|
||||||
|
isFlowFields_ConnectionInfo()
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowFields_PortInfo struct {
|
||||||
|
// TCP/UDP port information
|
||||||
|
PortInfo *PortInfo `protobuf:"bytes,8,opt,name=port_info,json=portInfo,proto3,oneof"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowFields_IcmpInfo struct {
|
||||||
|
// ICMP type and code
|
||||||
|
IcmpInfo *ICMPInfo `protobuf:"bytes,9,opt,name=icmp_info,json=icmpInfo,proto3,oneof"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FlowFields_PortInfo) isFlowFields_ConnectionInfo() {}
|
||||||
|
|
||||||
|
func (*FlowFields_IcmpInfo) isFlowFields_ConnectionInfo() {}
|
||||||
|
|
||||||
|
// TCP/UDP port information
|
||||||
|
type PortInfo struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
|
||||||
|
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *PortInfo) Reset() {
|
||||||
|
*x = PortInfo{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_flow_proto_msgTypes[3]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *PortInfo) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*PortInfo) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *PortInfo) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_flow_proto_msgTypes[3]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead.
|
||||||
|
func (*PortInfo) Descriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{3}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *PortInfo) GetSourcePort() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.SourcePort
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *PortInfo) GetDestPort() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.DestPort
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICMP message information
|
||||||
|
type ICMPInfo struct {
|
||||||
|
state protoimpl.MessageState
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|
||||||
|
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
|
||||||
|
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ICMPInfo) Reset() {
|
||||||
|
*x = ICMPInfo{}
|
||||||
|
if protoimpl.UnsafeEnabled {
|
||||||
|
mi := &file_flow_proto_msgTypes[4]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ICMPInfo) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*ICMPInfo) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *ICMPInfo) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_flow_proto_msgTypes[4]
|
||||||
|
if protoimpl.UnsafeEnabled && x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use ICMPInfo.ProtoReflect.Descriptor instead.
|
||||||
|
func (*ICMPInfo) Descriptor() ([]byte, []int) {
|
||||||
|
return file_flow_proto_rawDescGZIP(), []int{4}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ICMPInfo) GetIcmpType() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.IcmpType
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *ICMPInfo) GetIcmpCode() uint32 {
|
||||||
|
if x != nil {
|
||||||
|
return x.IcmpCode
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var File_flow_proto protoreflect.FileDescriptor
|
||||||
|
|
||||||
|
var file_flow_proto_rawDesc = []byte{
|
||||||
|
0x0a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x66, 0x6c,
|
||||||
|
0x6f, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||||
|
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
|
||||||
|
0x6f, 0x74, 0x6f, 0x22, 0xb2, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
|
||||||
|
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||||
|
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
|
||||||
|
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
|
||||||
|
0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75,
|
||||||
|
0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d,
|
||||||
|
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
|
||||||
|
0x5f, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c,
|
||||||
|
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x31, 0x0a, 0x0b, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x66, 0x69,
|
||||||
|
0x65, 0x6c, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x66, 0x6c, 0x6f,
|
||||||
|
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
|
||||||
|
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x22, 0x29, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77,
|
||||||
|
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e,
|
||||||
|
0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e,
|
||||||
|
0x74, 0x49, 0x64, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c,
|
||||||
|
0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||||
|
0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74,
|
||||||
|
0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77,
|
||||||
|
0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x72,
|
||||||
|
0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x72, 0x75,
|
||||||
|
0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f,
|
||||||
|
0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44,
|
||||||
|
0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
|
||||||
|
0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18,
|
||||||
|
0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12,
|
||||||
|
0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x06, 0x20, 0x01,
|
||||||
|
0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x17, 0x0a, 0x07,
|
||||||
|
0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x64,
|
||||||
|
0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x69, 0x6e,
|
||||||
|
0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e,
|
||||||
|
0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74,
|
||||||
|
0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x69, 0x6e, 0x66,
|
||||||
|
0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x49,
|
||||||
|
0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x49,
|
||||||
|
0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74,
|
||||||
|
0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78, 0x50, 0x61, 0x63, 0x6b, 0x65,
|
||||||
|
0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73,
|
||||||
|
0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74,
|
||||||
|
0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20,
|
||||||
|
0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08,
|
||||||
|
0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07,
|
||||||
|
0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x6f, 0x75, 0x72, 0x63,
|
||||||
|
0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0e, 0x20,
|
||||||
|
0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x65, 0x73, 0x6f, 0x75,
|
||||||
|
0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65,
|
||||||
|
0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0c, 0x52,
|
||||||
|
0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x42,
|
||||||
|
0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e,
|
||||||
|
0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f,
|
||||||
|
0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20,
|
||||||
|
0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12,
|
||||||
|
0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01,
|
||||||
|
0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08,
|
||||||
|
0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
|
||||||
|
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d,
|
||||||
|
0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f,
|
||||||
|
0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f,
|
||||||
|
0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59,
|
||||||
|
0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a,
|
||||||
|
0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08,
|
||||||
|
0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59,
|
||||||
|
0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72,
|
||||||
|
0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54,
|
||||||
|
0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a,
|
||||||
|
0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47,
|
||||||
|
0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65,
|
||||||
|
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12,
|
||||||
|
0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74,
|
||||||
|
0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
|
||||||
|
0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70,
|
||||||
|
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
file_flow_proto_rawDescOnce sync.Once
|
||||||
|
file_flow_proto_rawDescData = file_flow_proto_rawDesc
|
||||||
|
)
|
||||||
|
|
||||||
|
func file_flow_proto_rawDescGZIP() []byte {
|
||||||
|
file_flow_proto_rawDescOnce.Do(func() {
|
||||||
|
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(file_flow_proto_rawDescData)
|
||||||
|
})
|
||||||
|
return file_flow_proto_rawDescData
|
||||||
|
}
|
||||||
|
|
||||||
|
var file_flow_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||||
|
var file_flow_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||||
|
var file_flow_proto_goTypes = []interface{}{
|
||||||
|
(Type)(0), // 0: flow.Type
|
||||||
|
(Direction)(0), // 1: flow.Direction
|
||||||
|
(*FlowEvent)(nil), // 2: flow.FlowEvent
|
||||||
|
(*FlowEventAck)(nil), // 3: flow.FlowEventAck
|
||||||
|
(*FlowFields)(nil), // 4: flow.FlowFields
|
||||||
|
(*PortInfo)(nil), // 5: flow.PortInfo
|
||||||
|
(*ICMPInfo)(nil), // 6: flow.ICMPInfo
|
||||||
|
(*timestamppb.Timestamp)(nil), // 7: google.protobuf.Timestamp
|
||||||
|
}
|
||||||
|
var file_flow_proto_depIdxs = []int32{
|
||||||
|
7, // 0: flow.FlowEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
|
4, // 1: flow.FlowEvent.flow_fields:type_name -> flow.FlowFields
|
||||||
|
0, // 2: flow.FlowFields.type:type_name -> flow.Type
|
||||||
|
1, // 3: flow.FlowFields.direction:type_name -> flow.Direction
|
||||||
|
5, // 4: flow.FlowFields.port_info:type_name -> flow.PortInfo
|
||||||
|
6, // 5: flow.FlowFields.icmp_info:type_name -> flow.ICMPInfo
|
||||||
|
2, // 6: flow.FlowService.Events:input_type -> flow.FlowEvent
|
||||||
|
3, // 7: flow.FlowService.Events:output_type -> flow.FlowEventAck
|
||||||
|
7, // [7:8] is the sub-list for method output_type
|
||||||
|
6, // [6:7] is the sub-list for method input_type
|
||||||
|
6, // [6:6] is the sub-list for extension type_name
|
||||||
|
6, // [6:6] is the sub-list for extension extendee
|
||||||
|
0, // [0:6] is the sub-list for field type_name
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() { file_flow_proto_init() }
|
||||||
|
func file_flow_proto_init() {
|
||||||
|
if File_flow_proto != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !protoimpl.UnsafeEnabled {
|
||||||
|
file_flow_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*FlowEvent); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_flow_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*FlowEventAck); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_flow_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*FlowFields); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_flow_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*PortInfo); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_flow_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||||
|
switch v := v.(*ICMPInfo); i {
|
||||||
|
case 0:
|
||||||
|
return &v.state
|
||||||
|
case 1:
|
||||||
|
return &v.sizeCache
|
||||||
|
case 2:
|
||||||
|
return &v.unknownFields
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file_flow_proto_msgTypes[2].OneofWrappers = []interface{}{
|
||||||
|
(*FlowFields_PortInfo)(nil),
|
||||||
|
(*FlowFields_IcmpInfo)(nil),
|
||||||
|
}
|
||||||
|
type x struct{}
|
||||||
|
out := protoimpl.TypeBuilder{
|
||||||
|
File: protoimpl.DescBuilder{
|
||||||
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
|
RawDescriptor: file_flow_proto_rawDesc,
|
||||||
|
NumEnums: 2,
|
||||||
|
NumMessages: 5,
|
||||||
|
NumExtensions: 0,
|
||||||
|
NumServices: 1,
|
||||||
|
},
|
||||||
|
GoTypes: file_flow_proto_goTypes,
|
||||||
|
DependencyIndexes: file_flow_proto_depIdxs,
|
||||||
|
EnumInfos: file_flow_proto_enumTypes,
|
||||||
|
MessageInfos: file_flow_proto_msgTypes,
|
||||||
|
}.Build()
|
||||||
|
File_flow_proto = out.File
|
||||||
|
file_flow_proto_rawDesc = nil
|
||||||
|
file_flow_proto_goTypes = nil
|
||||||
|
file_flow_proto_depIdxs = nil
|
||||||
|
}
|
||||||
102
flow/proto/flow.proto
Normal file
102
flow/proto/flow.proto
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
|
option go_package = "/proto";
|
||||||
|
|
||||||
|
package flow;
|
||||||
|
|
||||||
|
service FlowService {
|
||||||
|
// Client to receiver streams of events and acknowledgements
|
||||||
|
rpc Events(stream FlowEvent) returns (stream FlowEventAck) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message FlowEvent {
|
||||||
|
// Unique client event identifier
|
||||||
|
bytes event_id = 1;
|
||||||
|
|
||||||
|
// When the event occurred
|
||||||
|
google.protobuf.Timestamp timestamp = 2;
|
||||||
|
|
||||||
|
// Public key of the sending peer
|
||||||
|
bytes public_key = 3;
|
||||||
|
|
||||||
|
FlowFields flow_fields = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FlowEventAck {
|
||||||
|
// Unique client event identifier that has been ack'ed
|
||||||
|
bytes event_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message FlowFields {
|
||||||
|
// Unique client flow session identifier
|
||||||
|
bytes flow_id = 1;
|
||||||
|
|
||||||
|
// Flow type
|
||||||
|
Type type = 2;
|
||||||
|
|
||||||
|
// RuleId identifies the rule that allowed or denied the connection
|
||||||
|
bytes rule_id = 3;
|
||||||
|
|
||||||
|
// Initiating traffic direction
|
||||||
|
Direction direction = 4;
|
||||||
|
|
||||||
|
// IP protocol number
|
||||||
|
uint32 protocol = 5;
|
||||||
|
|
||||||
|
// Source IP address
|
||||||
|
bytes source_ip = 6;
|
||||||
|
|
||||||
|
// Destination IP address
|
||||||
|
bytes dest_ip = 7;
|
||||||
|
|
||||||
|
// Layer 4 -specific information
|
||||||
|
oneof connection_info {
|
||||||
|
// TCP/UDP port information
|
||||||
|
PortInfo port_info = 8;
|
||||||
|
|
||||||
|
// ICMP type and code
|
||||||
|
ICMPInfo icmp_info = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of packets
|
||||||
|
uint64 rx_packets = 10;
|
||||||
|
uint64 tx_packets = 11;
|
||||||
|
|
||||||
|
// Number of bytes
|
||||||
|
uint64 rx_bytes = 12;
|
||||||
|
uint64 tx_bytes = 13;
|
||||||
|
|
||||||
|
// Resource ID
|
||||||
|
bytes source_resource_id = 14;
|
||||||
|
bytes dest_resource_id = 15;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flow event types
|
||||||
|
enum Type {
|
||||||
|
TYPE_UNKNOWN = 0;
|
||||||
|
TYPE_START = 1;
|
||||||
|
TYPE_END = 2;
|
||||||
|
TYPE_DROP = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flow direction
|
||||||
|
enum Direction {
|
||||||
|
DIRECTION_UNKNOWN = 0;
|
||||||
|
INGRESS = 1;
|
||||||
|
EGRESS = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCP/UDP port information
|
||||||
|
message PortInfo {
|
||||||
|
uint32 source_port = 1;
|
||||||
|
uint32 dest_port = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICMP message information
|
||||||
|
message ICMPInfo {
|
||||||
|
uint32 icmp_type = 1;
|
||||||
|
uint32 icmp_code = 2;
|
||||||
|
}
|
||||||
135
flow/proto/flow_grpc.pb.go
Normal file
135
flow/proto/flow_grpc.pb.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
|
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
context "context"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
codes "google.golang.org/grpc/codes"
|
||||||
|
status "google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
// Requires gRPC-Go v1.32.0 or later.
|
||||||
|
const _ = grpc.SupportPackageIsVersion7
|
||||||
|
|
||||||
|
// FlowServiceClient is the client API for FlowService service.
|
||||||
|
//
|
||||||
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
|
type FlowServiceClient interface {
|
||||||
|
// Client to receiver streams of events and acknowledgements
|
||||||
|
Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowServiceClient struct {
|
||||||
|
cc grpc.ClientConnInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFlowServiceClient(cc grpc.ClientConnInterface) FlowServiceClient {
|
||||||
|
return &flowServiceClient{cc}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error) {
|
||||||
|
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], "/flow.FlowService/Events", opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x := &flowServiceEventsClient{stream}
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowService_EventsClient interface {
|
||||||
|
Send(*FlowEvent) error
|
||||||
|
Recv() (*FlowEventAck, error)
|
||||||
|
grpc.ClientStream
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowServiceEventsClient struct {
|
||||||
|
grpc.ClientStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *flowServiceEventsClient) Send(m *FlowEvent) error {
|
||||||
|
return x.ClientStream.SendMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *flowServiceEventsClient) Recv() (*FlowEventAck, error) {
|
||||||
|
m := new(FlowEventAck)
|
||||||
|
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlowServiceServer is the server API for FlowService service.
|
||||||
|
// All implementations must embed UnimplementedFlowServiceServer
|
||||||
|
// for forward compatibility
|
||||||
|
type FlowServiceServer interface {
|
||||||
|
// Client to receiver streams of events and acknowledgements
|
||||||
|
Events(FlowService_EventsServer) error
|
||||||
|
mustEmbedUnimplementedFlowServiceServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnimplementedFlowServiceServer must be embedded to have forward compatible implementations.
|
||||||
|
type UnimplementedFlowServiceServer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedFlowServiceServer) Events(FlowService_EventsServer) error {
|
||||||
|
return status.Errorf(codes.Unimplemented, "method Events not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedFlowServiceServer) mustEmbedUnimplementedFlowServiceServer() {}
|
||||||
|
|
||||||
|
// UnsafeFlowServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
|
// Use of this interface is not recommended, as added methods to FlowServiceServer will
|
||||||
|
// result in compilation errors.
|
||||||
|
type UnsafeFlowServiceServer interface {
|
||||||
|
mustEmbedUnimplementedFlowServiceServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterFlowServiceServer(s grpc.ServiceRegistrar, srv FlowServiceServer) {
|
||||||
|
s.RegisterService(&FlowService_ServiceDesc, srv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _FlowService_Events_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||||
|
return srv.(FlowServiceServer).Events(&flowServiceEventsServer{stream})
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowService_EventsServer interface {
|
||||||
|
Send(*FlowEventAck) error
|
||||||
|
Recv() (*FlowEvent, error)
|
||||||
|
grpc.ServerStream
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowServiceEventsServer struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *flowServiceEventsServer) Send(m *FlowEventAck) error {
|
||||||
|
return x.ServerStream.SendMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *flowServiceEventsServer) Recv() (*FlowEvent, error) {
|
||||||
|
m := new(FlowEvent)
|
||||||
|
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlowService_ServiceDesc is the grpc.ServiceDesc for FlowService service.
|
||||||
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
|
// and not to be introspected or modified (even as a copy)
|
||||||
|
var FlowService_ServiceDesc = grpc.ServiceDesc{
|
||||||
|
ServiceName: "flow.FlowService",
|
||||||
|
HandlerType: (*FlowServiceServer)(nil),
|
||||||
|
Methods: []grpc.MethodDesc{},
|
||||||
|
Streams: []grpc.StreamDesc{
|
||||||
|
{
|
||||||
|
StreamName: "Events",
|
||||||
|
Handler: _FlowService_Events_Handler,
|
||||||
|
ServerStreams: true,
|
||||||
|
ClientStreams: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Metadata: "flow.proto",
|
||||||
|
}
|
||||||
17
flow/proto/generate.sh
Executable file
17
flow/proto/generate.sh
Executable file
@@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if ! which realpath > /dev/null 2>&1
|
||||||
|
then
|
||||||
|
echo realpath is not installed
|
||||||
|
echo run: brew install coreutils
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
old_pwd=$(pwd)
|
||||||
|
script_path=$(dirname $(realpath "$0"))
|
||||||
|
cd "$script_path"
|
||||||
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
||||||
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||||
|
protoc -I ./ ./flow.proto --go_out=../ --go-grpc_out=../
|
||||||
|
cd "$old_pwd"
|
||||||
4
go.mod
4
go.mod
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250307154727-58660ea9a141
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250310094048-24724cc8c9c3
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
@@ -82,6 +82,8 @@ require (
|
|||||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
||||||
github.com/things-go/go-socks5 v0.0.4
|
github.com/things-go/go-socks5 v0.0.4
|
||||||
|
github.com/ti-mo/conntrack v0.5.1
|
||||||
|
github.com/ti-mo/netfilter v0.5.2
|
||||||
github.com/yusufpapurcu/wmi v1.2.4
|
github.com/yusufpapurcu/wmi v1.2.4
|
||||||
github.com/zcalusic/sysinfo v1.1.3
|
github.com/zcalusic/sysinfo v1.1.3
|
||||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
|
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -529,8 +529,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250307154727-58660ea9a141 h1:GZUkZd9ZMBGahNt+AbYYvZrSMpOnaBLjHiBbloOE7sc=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250310094048-24724cc8c9c3 h1:+DNgrPvpdNZR/UiyFQ7fb8weENmy7rB5S54zIoTPhYE=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250307154727-58660ea9a141/go.mod h1:A5QUfEZb5J3tw8EUB9e3q7Bgd/JtC0WlFT1onf3HPCY=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250310094048-24724cc8c9c3/go.mod h1:NZ63GQu65YcqarxJxY9p05ukZY16sEmcW9O3GX92T/A=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
@@ -699,6 +699,10 @@ github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3K
|
|||||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw=
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw=
|
||||||
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
|
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
|
||||||
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
|
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
|
||||||
|
github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0=
|
||||||
|
github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o=
|
||||||
|
github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40=
|
||||||
|
github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo=
|
||||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
|
github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU=
|
||||||
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
|
github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY=
|
||||||
|
|||||||
@@ -10,14 +10,13 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@@ -73,13 +72,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ import (
|
|||||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
@@ -203,13 +202,14 @@ var (
|
|||||||
}
|
}
|
||||||
|
|
||||||
userManager := users.NewManager(store)
|
userManager := users.NewManager(store)
|
||||||
settingsManager := settings.NewManager(store)
|
extraSettingsManager := integrations.NewManager()
|
||||||
|
settingsManager := settings.NewManager(store, userManager, extraSettingsManager)
|
||||||
permissionsManager := permissions.NewManager(userManager, settingsManager)
|
permissionsManager := permissions.NewManager(userManager, settingsManager)
|
||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
proxyController := integrations.NewController(store)
|
proxyController := integrations.NewController(store)
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController)
|
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to build default manager: %v", err)
|
return fmt.Errorf("failed to build default manager: %v", err)
|
||||||
}
|
}
|
||||||
@@ -276,7 +276,7 @@ var (
|
|||||||
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
||||||
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
||||||
|
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, integratedPeerValidator, proxyController, permissionsManager, peersManager)
|
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, integratedPeerValidator, proxyController, permissionsManager, peersManager, settingsManager)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ cd "$script_path"
|
|||||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
||||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||||
protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
|
protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
|
||||||
cd "$old_pwd"
|
cd "$old_pwd"
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
|||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
import "google/protobuf/timestamp.proto";
|
import "google/protobuf/timestamp.proto";
|
||||||
|
import "google/protobuf/duration.proto";
|
||||||
|
|
||||||
option go_package = "/proto";
|
option go_package = "/proto";
|
||||||
|
|
||||||
@@ -97,7 +98,7 @@ message LoginRequest {
|
|||||||
string jwtToken = 3;
|
string jwtToken = 3;
|
||||||
// Can be absent for now.
|
// Can be absent for now.
|
||||||
PeerKeys peerKeys = 4;
|
PeerKeys peerKeys = 4;
|
||||||
|
|
||||||
repeated string dnsLabels = 5;
|
repeated string dnsLabels = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,6 +192,8 @@ message NetbirdConfig {
|
|||||||
HostConfig signal = 3;
|
HostConfig signal = 3;
|
||||||
|
|
||||||
RelayConfig relay = 4;
|
RelayConfig relay = 4;
|
||||||
|
|
||||||
|
FlowConfig flow = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
||||||
@@ -214,6 +217,17 @@ message RelayConfig {
|
|||||||
string tokenSignature = 3;
|
string tokenSignature = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message FlowConfig {
|
||||||
|
string url = 1;
|
||||||
|
string tokenPayload = 2;
|
||||||
|
string tokenSignature = 3;
|
||||||
|
google.protobuf.Duration interval = 4;
|
||||||
|
bool enabled = 5;
|
||||||
|
|
||||||
|
// counters determines if flow packets and bytes counters should be sent
|
||||||
|
bool counters = 6;
|
||||||
|
}
|
||||||
|
|
||||||
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
||||||
// Mostly used for TURN servers
|
// Mostly used for TURN servers
|
||||||
message ProtectedHostConfig {
|
message ProtectedHostConfig {
|
||||||
@@ -434,6 +448,9 @@ message FirewallRule {
|
|||||||
RuleProtocol Protocol = 4;
|
RuleProtocol Protocol = 4;
|
||||||
string Port = 5;
|
string Port = 5;
|
||||||
PortInfo PortInfo = 6;
|
PortInfo PortInfo = 6;
|
||||||
|
|
||||||
|
// PolicyID is the ID of the policy that this rule belongs to
|
||||||
|
bytes PolicyID = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
message NetworkAddress {
|
message NetworkAddress {
|
||||||
@@ -483,6 +500,9 @@ message RouteFirewallRule {
|
|||||||
|
|
||||||
// CustomProtocol is a custom protocol ID.
|
// CustomProtocol is a custom protocol ID.
|
||||||
uint32 customProtocol = 8;
|
uint32 customProtocol = 8;
|
||||||
|
|
||||||
|
// PolicyID is the ID of the policy that this rule belongs to
|
||||||
|
bytes PolicyID = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ForwardingRule {
|
message ForwardingRule {
|
||||||
@@ -493,7 +513,6 @@ message ForwardingRule {
|
|||||||
PortInfo destinationPort = 2;
|
PortInfo destinationPort = 2;
|
||||||
|
|
||||||
// IP address of the translated address (remote peer) to send traffic to
|
// IP address of the translated address (remote peer) to send traffic to
|
||||||
// todo type pending
|
|
||||||
bytes translatedAddress = 3;
|
bytes translatedAddress = 3;
|
||||||
|
|
||||||
// Translated port information, where the traffic should be forwarded to
|
// Translated port information, where the traffic should be forwarded to
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -22,7 +21,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
@@ -31,6 +30,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -49,104 +49,11 @@ const (
|
|||||||
|
|
||||||
type userLoggedInOnce bool
|
type userLoggedInOnce bool
|
||||||
|
|
||||||
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
|
|
||||||
|
|
||||||
func cacheEntryExpiration() time.Duration {
|
func cacheEntryExpiration() time.Duration {
|
||||||
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
|
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
|
||||||
return time.Duration(r) * time.Millisecond
|
return time.Duration(r) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountManager interface {
|
|
||||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error)
|
|
||||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
|
||||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration,
|
|
||||||
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
|
||||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
|
||||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
|
||||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
|
||||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
|
||||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
|
||||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
|
||||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
|
||||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
|
||||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
|
||||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
|
||||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
|
||||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
|
||||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
|
||||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
|
||||||
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
|
||||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
|
||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
|
||||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
|
||||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
|
||||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
|
||||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
|
||||||
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
|
||||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
|
||||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
|
||||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
|
||||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
|
|
||||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
|
||||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
|
||||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
|
||||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
|
||||||
SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
|
||||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
|
||||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
|
||||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
|
||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
|
||||||
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
|
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
|
||||||
SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
|
||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
|
||||||
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
|
||||||
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
|
|
||||||
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
|
||||||
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
|
||||||
GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
|
||||||
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
|
||||||
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
|
||||||
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
|
|
||||||
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
|
||||||
GetDNSDomain() string
|
|
||||||
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
|
||||||
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
|
||||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
|
|
||||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
|
||||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
|
||||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
|
||||||
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
|
||||||
SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
|
||||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
|
||||||
HasConnectedChannel(peerID string) bool
|
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
|
||||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
|
||||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
|
||||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
|
||||||
GetIdpManager() idp.Manager
|
|
||||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
|
||||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
|
||||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
|
||||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
|
||||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
|
||||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
Store store.Store
|
Store store.Store
|
||||||
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
|
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
|
||||||
@@ -156,7 +63,7 @@ type DefaultAccountManager struct {
|
|||||||
peersUpdateManager *PeersUpdateManager
|
peersUpdateManager *PeersUpdateManager
|
||||||
idpManager idp.Manager
|
idpManager idp.Manager
|
||||||
cacheManager cache.CacheInterface[[]*idp.UserData]
|
cacheManager cache.CacheInterface[[]*idp.UserData]
|
||||||
externalCacheManager ExternalCacheManager
|
externalCacheManager account.ExternalCacheManager
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
eventStore activity.Store
|
eventStore activity.Store
|
||||||
geo geolocation.Geolocation
|
geo geolocation.Geolocation
|
||||||
@@ -164,6 +71,7 @@ type DefaultAccountManager struct {
|
|||||||
requestBuffer *AccountRequestBuffer
|
requestBuffer *AccountRequestBuffer
|
||||||
|
|
||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
|
settingsManager settings.Manager
|
||||||
|
|
||||||
// singleAccountMode indicates whether the instance has a single account.
|
// singleAccountMode indicates whether the instance has a single account.
|
||||||
// If true, then every new user will end up under the same account.
|
// If true, then every new user will end up under the same account.
|
||||||
@@ -249,6 +157,7 @@ func BuildManager(
|
|||||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||||
metrics telemetry.AppMetrics,
|
metrics telemetry.AppMetrics,
|
||||||
proxyController port_forwarding.Controller,
|
proxyController port_forwarding.Controller,
|
||||||
|
settingsManager settings.Manager,
|
||||||
) (*DefaultAccountManager, error) {
|
) (*DefaultAccountManager, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -272,6 +181,7 @@ func BuildManager(
|
|||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
requestBuffer: NewAccountRequestBuffer(ctx, store),
|
requestBuffer: NewAccountRequestBuffer(ctx, store),
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
|
settingsManager: settingsManager,
|
||||||
}
|
}
|
||||||
accountsCounter, err := store.GetAccountsCounter(ctx)
|
accountsCounter, err := store.GetAccountsCounter(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -317,7 +227,7 @@ func BuildManager(
|
|||||||
return am, nil
|
return am, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetExternalCacheManager() ExternalCacheManager {
|
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
|
||||||
return am.externalCacheManager
|
return am.externalCacheManager
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,6 +316,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.settingsManager.UpdateExtraSettings(ctx, accountID, newSettings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
go am.UpdateAccountPeers(ctx, accountID)
|
go am.UpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
@@ -1476,7 +1391,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
|||||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||||
defer peerUnlock()
|
defer peerUnlock()
|
||||||
|
|
||||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1516,7 +1431,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
|||||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||||
defer unlockPeer()
|
defer unlockPeer()
|
||||||
|
|
||||||
_, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mapError(ctx, err)
|
return mapError(ctx, err)
|
||||||
}
|
}
|
||||||
@@ -1685,3 +1600,7 @@ func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, ma
|
|||||||
|
|
||||||
return newAutoGroups, jwtAutoGroups
|
return newAutoGroups, jwtAutoGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetStore() store.Store {
|
||||||
|
return am.Store
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
package account
|
|
||||||
|
|
||||||
type ExtraSettings struct {
|
|
||||||
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
|
|
||||||
PeerApprovalEnabled bool
|
|
||||||
|
|
||||||
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
|
|
||||||
IntegratedValidatorGroups []string `gorm:"serializer:json"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy copies the ExtraSettings struct
|
|
||||||
func (e *ExtraSettings) Copy() *ExtraSettings {
|
|
||||||
var cpGroup []string
|
|
||||||
|
|
||||||
return &ExtraSettings{
|
|
||||||
PeerApprovalEnabled: e.PeerApprovalEnabled,
|
|
||||||
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
115
management/server/account/manager.go
Normal file
115
management/server/account/manager.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/eko/gocache/v3/cache"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ExternalCacheManager cache.CacheInterface[*idp.UserData]
|
||||||
|
|
||||||
|
type Manager interface {
|
||||||
|
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error)
|
||||||
|
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||||
|
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration,
|
||||||
|
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||||
|
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||||
|
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||||
|
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||||
|
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||||
|
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||||
|
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||||
|
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
||||||
|
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||||
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||||
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||||
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||||
|
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||||
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
|
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||||
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
|
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||||
|
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||||
|
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
|
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||||
|
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||||
|
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
|
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||||
|
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||||
|
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
|
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||||
|
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
|
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||||
|
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
|
||||||
|
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||||
|
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||||
|
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||||
|
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||||
|
SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||||
|
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||||
|
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||||
|
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||||
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
|
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
|
||||||
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
||||||
|
SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
||||||
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||||
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
|
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||||
|
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
|
||||||
|
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||||
|
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||||
|
GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
|
CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
|
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
|
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
|
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
|
GetDNSDomain() string
|
||||||
|
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||||
|
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||||
|
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
|
||||||
|
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||||
|
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
|
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
||||||
|
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
|
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
|
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||||
|
HasConnectedChannel(peerID string) bool
|
||||||
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
|
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
|
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
|
GetIdpManager() idp.Manager
|
||||||
|
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||||
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
||||||
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||||
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||||
|
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||||
|
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||||
|
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
|
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||||
|
GetStore() store.Store
|
||||||
|
}
|
||||||
@@ -13,7 +13,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
@@ -36,7 +38,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
|
func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
peer := &nbpeer.Peer{
|
peer := &nbpeer.Peer{
|
||||||
Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=",
|
Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=",
|
||||||
@@ -1403,7 +1405,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
|||||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEvent(t *testing.T, accountID string, manager AccountManager, eventType activity.Activity) *activity.Event {
|
func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventType activity.Activity) *activity.Event {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -2809,7 +2811,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock())
|
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -3024,7 +3026,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||||
WireGuardPubKey: account.Peers["peer-1"].Key,
|
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||||
SSHKey: "someKey",
|
SSHKey: "someKey",
|
||||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||||
@@ -3099,7 +3101,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
|
||||||
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||||
SSHKey: "someKey",
|
SSHKey: "someKey",
|
||||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -209,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock())
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -34,7 +35,7 @@ type ephemeralPeer struct {
|
|||||||
// automatically. Inactivity means the peer disconnected from the Management server.
|
// automatically. Inactivity means the peer disconnected from the Management server.
|
||||||
type EphemeralManager struct {
|
type EphemeralManager struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
accountManager AccountManager
|
accountManager nbAccount.Manager
|
||||||
|
|
||||||
headPeer *ephemeralPeer
|
headPeer *ephemeralPeer
|
||||||
tailPeer *ephemeralPeer
|
tailPeer *ephemeralPeer
|
||||||
@@ -43,7 +44,7 @@ type EphemeralManager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEphemeralManager instantiate new EphemeralManager
|
// NewEphemeralManager instantiate new EphemeralManager
|
||||||
func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager {
|
func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager {
|
||||||
return &EphemeralManager{
|
return &EphemeralManager{
|
||||||
store: store,
|
store: store,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
@@ -27,7 +28,7 @@ func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStren
|
|||||||
}
|
}
|
||||||
|
|
||||||
type MocAccountManager struct {
|
type MocAccountManager struct {
|
||||||
AccountManager
|
nbAccount.Manager
|
||||||
store *MockStore
|
store *MockStore
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,6 +37,10 @@ func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, user
|
|||||||
return nil //nolint:nil
|
return nil //nolint:nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a MocAccountManager) GetStore() store.Store {
|
||||||
|
return a.store
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewManager(t *testing.T) {
|
func TestNewManager(t *testing.T) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
timeNow = func() time.Time {
|
timeNow = func() time.Time {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
s "github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
@@ -24,13 +24,13 @@ type Manager interface {
|
|||||||
type managerImpl struct {
|
type managerImpl struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
accountManager s.AccountManager
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockManager struct {
|
type mockManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager {
|
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
|
||||||
return &managerImpl{
|
return &managerImpl{
|
||||||
store: store,
|
store: store,
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
|
|||||||
@@ -18,8 +18,11 @@ import (
|
|||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@@ -32,7 +35,7 @@ import (
|
|||||||
|
|
||||||
// GRPCServer an instance of a Management gRPC API server
|
// GRPCServer an instance of a Management gRPC API server
|
||||||
type GRPCServer struct {
|
type GRPCServer struct {
|
||||||
accountManager AccountManager
|
accountManager account.Manager
|
||||||
settingsManager settings.Manager
|
settingsManager settings.Manager
|
||||||
wgKey wgtypes.Key
|
wgKey wgtypes.Key
|
||||||
proto.UnimplementedManagementServiceServer
|
proto.UnimplementedManagementServiceServer
|
||||||
@@ -49,7 +52,7 @@ type GRPCServer struct {
|
|||||||
func NewServer(
|
func NewServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *Config,
|
config *Config,
|
||||||
accountManager AccountManager,
|
accountManager account.Manager,
|
||||||
settingsManager settings.Manager,
|
settingsManager settings.Manager,
|
||||||
peersUpdateManager *PeersUpdateManager,
|
peersUpdateManager *PeersUpdateManager,
|
||||||
secretsManager SecretsManager,
|
secretsManager SecretsManager,
|
||||||
@@ -457,7 +460,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
|||||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
|
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||||
WireGuardPubKey: peerKey.String(),
|
WireGuardPubKey: peerKey.String(),
|
||||||
SSHKey: string(sshKey),
|
SSHKey: string(sshKey),
|
||||||
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
|
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
|
||||||
@@ -486,7 +489,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
|||||||
|
|
||||||
// if peer has reached this point then it has logged in
|
// if peer has reached this point then it has logged in
|
||||||
loginResp := &proto.LoginResponse{
|
loginResp := &proto.LoginResponse{
|
||||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken),
|
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false),
|
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false),
|
||||||
Checks: toProtocolChecks(ctx, postureChecks),
|
Checks: toProtocolChecks(ctx, postureChecks),
|
||||||
}
|
}
|
||||||
@@ -544,7 +547,7 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token) *proto.NetbirdConfig {
|
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -592,7 +595,7 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &proto.NetbirdConfig{
|
nbConfig := &proto.NetbirdConfig{
|
||||||
Stuns: stuns,
|
Stuns: stuns,
|
||||||
Turns: turns,
|
Turns: turns,
|
||||||
Signal: &proto.HostConfig{
|
Signal: &proto.HostConfig{
|
||||||
@@ -601,6 +604,10 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token)
|
|||||||
},
|
},
|
||||||
Relay: relayCfg,
|
Relay: relayCfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings)
|
||||||
|
|
||||||
|
return nbConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dnsResolutionOnRoutingPeerEnabled bool) *proto.PeerConfig {
|
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dnsResolutionOnRoutingPeerEnabled bool) *proto.PeerConfig {
|
||||||
@@ -614,10 +621,10 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnbled bool) *proto.SyncResponse {
|
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials),
|
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings),
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnbled),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
Routes: toProtocolRoutes(networkMap.Routes),
|
Routes: toProtocolRoutes(networkMap.Routes),
|
||||||
@@ -693,12 +700,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, peer.UserID)
|
settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "error handling request")
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled)
|
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra)
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -106,6 +106,10 @@ components:
|
|||||||
description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
|
description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
network_traffic_logs_enabled:
|
||||||
|
description: Enables or disables network traffic logs. If enabled, all network traffic logs from peers will be stored.
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
AccountRequest:
|
AccountRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|||||||
@@ -230,6 +230,9 @@ type Account struct {
|
|||||||
|
|
||||||
// AccountExtraSettings defines model for AccountExtraSettings.
|
// AccountExtraSettings defines model for AccountExtraSettings.
|
||||||
type AccountExtraSettings struct {
|
type AccountExtraSettings struct {
|
||||||
|
// NetworkTrafficLogsEnabled Enables or disables network traffic logs. If enabled, all network traffic logs from peers will be stored.
|
||||||
|
NetworkTrafficLogsEnabled *bool `json:"network_traffic_logs_enabled,omitempty"`
|
||||||
|
|
||||||
// PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
|
// PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
|
||||||
PeerApprovalEnabled *bool `json:"peer_approval_enabled,omitempty"`
|
PeerApprovalEnabled *bool `json:"peer_approval_enabled,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
|
||||||
s "github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||||
@@ -41,7 +43,7 @@ const apiPrefix = "/api"
|
|||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(
|
func NewAPIHandler(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
accountManager s.AccountManager,
|
accountManager account.Manager,
|
||||||
networksManager nbnetworks.Manager,
|
networksManager nbnetworks.Manager,
|
||||||
resourceManager resources.Manager,
|
resourceManager resources.Manager,
|
||||||
routerManager routers.Manager,
|
routerManager routers.Manager,
|
||||||
@@ -53,6 +55,7 @@ func NewAPIHandler(
|
|||||||
proxyController port_forwarding.Controller,
|
proxyController port_forwarding.Controller,
|
||||||
permissionsManager permissions.Manager,
|
permissionsManager permissions.Manager,
|
||||||
peersManager nbpeers.Manager,
|
peersManager nbpeers.Manager,
|
||||||
|
settingsManager settings.Manager,
|
||||||
) (http.Handler, error) {
|
) (http.Handler, error) {
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
@@ -73,11 +76,11 @@ func NewAPIHandler(
|
|||||||
|
|
||||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
|
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController); err != nil {
|
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts.AddEndpoints(accountManager, router)
|
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||||
peers.AddEndpoints(accountManager, router)
|
peers.AddEndpoints(accountManager, router)
|
||||||
users.AddEndpoints(accountManager, router)
|
users.AddEndpoints(accountManager, router)
|
||||||
setup_keys.AddEndpoints(accountManager, router)
|
setup_keys.AddEndpoints(accountManager, router)
|
||||||
|
|||||||
@@ -7,31 +7,33 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handler is a handler that handles the server.Account HTTP endpoints
|
// handler is a handler that handles the server.Account HTTP endpoints
|
||||||
type handler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager account.Manager
|
||||||
|
settingsManager settings.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||||
accountsHandler := newHandler(accountManager)
|
accountsHandler := newHandler(accountManager, settingsManager)
|
||||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// newHandler creates a new handler HTTP handler
|
// newHandler creates a new handler HTTP handler
|
||||||
func newHandler(accountManager server.AccountManager) *handler {
|
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
|
||||||
return &handler{
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
|
settingsManager: settingsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,7 +47,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||||
|
|
||||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
|
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@@ -89,7 +91,14 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Settings.Extra != nil {
|
if req.Settings.Extra != nil {
|
||||||
settings.Extra = &account.ExtraSettings{PeerApprovalEnabled: *req.Settings.Extra.PeerApprovalEnabled}
|
flowEnabled := false
|
||||||
|
if req.Settings.Extra.NetworkTrafficLogsEnabled != nil {
|
||||||
|
flowEnabled = *req.Settings.Extra.NetworkTrafficLogsEnabled
|
||||||
|
}
|
||||||
|
settings.Extra = &types.ExtraSettings{
|
||||||
|
PeerApprovalEnabled: *req.Settings.Extra.PeerApprovalEnabled,
|
||||||
|
FlowEnabled: flowEnabled,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Settings.JwtGroupsEnabled != nil {
|
if req.Settings.JwtGroupsEnabled != nil {
|
||||||
@@ -163,7 +172,10 @@ func toAccountResponse(accountID string, settings *types.Settings) *api.Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
|
apiSettings.Extra = &api.AccountExtraSettings{
|
||||||
|
PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled,
|
||||||
|
NetworkTrafficLogsEnabled: &settings.Extra.FlowEnabled,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
|
|||||||
@@ -16,11 +16,16 @@ import (
|
|||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initAccountsTestData(account *types.Account) *handler {
|
func initAccountsTestData(account *types.Account) *handler {
|
||||||
|
settingsMock := settings.NewManagerMock()
|
||||||
|
settingsMock.GetSettingsFunc = func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||||
|
return account.Settings, nil
|
||||||
|
}
|
||||||
return &handler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||||
@@ -41,6 +46,7 @@ func initAccountsTestData(account *types.Account) *handler {
|
|||||||
return accCopy, nil
|
return accCopy, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
settingsManager: settingsMock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
@@ -16,22 +16,22 @@ import (
|
|||||||
|
|
||||||
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||||
type dnsSettingsHandler struct {
|
type dnsSettingsHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||||
addDNSSettingEndpoint(accountManager, router)
|
addDNSSettingEndpoint(accountManager, router)
|
||||||
addDNSNameserversEndpoint(accountManager, router)
|
addDNSNameserversEndpoint(accountManager, router)
|
||||||
}
|
}
|
||||||
|
|
||||||
func addDNSSettingEndpoint(accountManager server.AccountManager, router *mux.Router) {
|
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||||
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
|
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
|
||||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
||||||
func newDNSSettingsHandler(accountManager server.AccountManager) *dnsSettingsHandler {
|
func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
|
||||||
return &dnsSettingsHandler{accountManager: accountManager}
|
return &dnsSettingsHandler{accountManager: accountManager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
@@ -18,10 +18,10 @@ import (
|
|||||||
|
|
||||||
// nameserversHandler is the nameserver group handler of the account
|
// nameserversHandler is the nameserver group handler of the account
|
||||||
type nameserversHandler struct {
|
type nameserversHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux.Router) {
|
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||||
nameserversHandler := newNameserversHandler(accountManager)
|
nameserversHandler := newNameserversHandler(accountManager)
|
||||||
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
||||||
@@ -31,7 +31,7 @@ func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newNameserversHandler returns a new instance of nameserversHandler handler
|
// newNameserversHandler returns a new instance of nameserversHandler handler
|
||||||
func newNameserversHandler(accountManager server.AccountManager) *nameserversHandler {
|
func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
|
||||||
return &nameserversHandler{accountManager: accountManager}
|
return &nameserversHandler{accountManager: accountManager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -17,16 +17,16 @@ import (
|
|||||||
|
|
||||||
// handler HTTP handler
|
// handler HTTP handler
|
||||||
type handler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||||
eventsHandler := newHandler(accountManager)
|
eventsHandler := newHandler(accountManager)
|
||||||
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// newHandler creates a new events handler
|
// newHandler creates a new events handler
|
||||||
func newHandler(accountManager server.AccountManager) *handler {
|
func newHandler(accountManager account.Manager) *handler {
|
||||||
return &handler{accountManager: accountManager}
|
return &handler{accountManager: accountManager}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@@ -19,10 +19,10 @@ import (
|
|||||||
|
|
||||||
// handler is a handler that returns groups of the account
|
// handler is a handler that returns groups of the account
|
||||||
type handler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager account.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||||
groupsHandler := newHandler(accountManager)
|
groupsHandler := newHandler(accountManager)
|
||||||
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
||||||
@@ -32,7 +32,7 @@ func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newHandler creates a new groups handler
|
// newHandler creates a new groups handler
|
||||||
func newHandler(accountManager server.AccountManager) *handler {
|
func newHandler(accountManager account.Manager) *handler {
|
||||||
return &handler{
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
s "github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@@ -28,12 +28,12 @@ type handler struct {
|
|||||||
networksManager networks.Manager
|
networksManager networks.Manager
|
||||||
resourceManager resources.Manager
|
resourceManager resources.Manager
|
||||||
routerManager routers.Manager
|
routerManager routers.Manager
|
||||||
accountManager s.AccountManager
|
accountManager account.Manager
|
||||||
|
|
||||||
groupsManager groups.Manager
|
groupsManager groups.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, router *mux.Router) {
|
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
|
||||||
addRouterEndpoints(routerManager, router)
|
addRouterEndpoints(routerManager, router)
|
||||||
addResourceEndpoints(resourceManager, groupsManager, router)
|
addResourceEndpoints(resourceManager, groupsManager, router)
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma
|
|||||||
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager) *handler {
|
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
|
||||||
return &handler{
|
return &handler{
|
||||||
networksManager: networksManager,
|
networksManager: networksManager,
|
||||||
resourceManager: resourceManager,
|
resourceManager: resourceManager,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user