Compare commits

...

19 Commits

Author SHA1 Message Date
bcmmbaga
ed3f549072 Merge branch 'main' into traffic-correlation-policy 2025-05-19 15:10:22 +03:00
bcmmbaga
c953e141bd Merge branch 'unidirectional-policy-all-ports' into traffic-correlation-policy 2025-05-19 11:27:54 +03:00
bcmmbaga
23fefab827 Merge branch 'flow-events-correlation' into traffic-correlation-policy 2025-05-19 11:27:25 +03:00
bcmmbaga
c33fc4f2e9 Merge branch 'policy-port-ranges' into traffic-correlation-policy 2025-05-19 11:27:17 +03:00
bcmmbaga
a1f8076dae Merge branch 'main' into flow-events-correlation 2025-05-18 14:35:48 +03:00
bcmmbaga
7275731762 allow unidirectional policy for all ports
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-14 22:01:50 +03:00
bcmmbaga
aa41e0d6af add tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-14 17:34:44 +03:00
bcmmbaga
1a734f0a50 add port range support in firewall rules
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-14 15:03:56 +03:00
bcmmbaga
a76f075e13 Merge branch 'main' into flow-events-correlation 2025-05-13 17:36:55 +03:00
bcmmbaga
de7cf68169 regenerate api specs
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-30 12:59:00 +03:00
bcmmbaga
6ada015360 Merge branch 'main' into flow-events-correlation
# Conflicts:
#	management/server/http/api/openapi.yml
2025-04-30 12:54:29 +03:00
bcmmbaga
bf091cdaff Fix func signature
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-25 17:54:46 +03:00
bcmmbaga
34e74ffb8a fix sonar
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-22 15:10:49 +03:00
bcmmbaga
f2e2e93bea Reuse MySQL and Postgres containers in tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-22 14:49:21 +03:00
bcmmbaga
1dbca5a772 fix test containers cleanup
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-22 12:36:11 +03:00
bcmmbaga
c32f02ac6d Refactor test container functions to return DSN string
Updated MySQL and PostgreSQL test container functions to return the DSN string instead of setting environment variables

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-18 23:46:01 +03:00
bcmmbaga
cab9b4caf4 Update user properties to remove nullable types
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-16 16:10:18 +03:00
bcmmbaga
446ad0e9de Remove id and received timestamp from correlated event
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-16 00:55:38 +03:00
bcmmbaga
263bf53c80 update network traffic schema
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-15 12:38:37 +03:00
8 changed files with 253 additions and 141 deletions

View File

@@ -1920,13 +1920,71 @@ components:
- os - os
- address - address
- dns_label - dns_label
NetworkTrafficEvent: NetworkTrafficUser:
type: object type: object
properties: properties:
id: id:
type: string type: string
description: "ID of the event. Unique." description: "UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated)."
example: "18e204d6-f7c6-405d-8025-70becb216add" example: "google-oauth2|123456789012345678901"
email:
type: string
description: "Email of the user who initiated the event (if any)."
example: "alice@netbird.io"
name:
type: string
description: "Name of the user who initiated the event (if any)."
example: "Alice Smith"
required:
- id
- email
- name
NetworkTrafficPolicy:
type: object
properties:
id:
type: string
description: "ID of the policy that allowed this event."
example: "ch8i4ug6lnn4g9hqv7m0"
name:
type: string
description: "Name of the policy that allowed this event."
example: "All to All"
required:
- id
- name
NetworkTrafficICMP:
type: object
properties:
type:
type: integer
description: "ICMP type (if applicable)."
example: 8
code:
type: integer
description: "ICMP code (if applicable)."
example: 0
required:
- type
- code
NetworkTrafficSubEvent:
type: object
properties:
type:
type: string
description: Type of the event (e.g., TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP).
example: TYPE_START
timestamp:
type: string
format: date-time
description: Timestamp of the event as sent by the peer.
example: 2025-03-20T16:23:58.125397Z
required:
- type
- timestamp
NetworkTrafficEvent:
type: object
properties:
flow_id: flow_id:
type: string type: string
description: "FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection)." description: "FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection)."
@@ -1935,43 +1993,20 @@ components:
type: string type: string
description: "ID of the reporter of the event (e.g., the peer that reported the event)." description: "ID of the reporter of the event (e.g., the peer that reported the event)."
example: "ch8i4ug6lnn4g9hqv7m0" example: "ch8i4ug6lnn4g9hqv7m0"
timestamp:
type: string
format: date-time
description: "Timestamp of the event. Send by the peer."
example: "2025-03-20T16:23:58.125397Z"
receive_timestamp:
type: string
format: date-time
description: "Timestamp when the event was received by our API."
example: "2025-03-20T16:23:58.125397Z"
source: source:
$ref: '#/components/schemas/NetworkTrafficEndpoint' $ref: '#/components/schemas/NetworkTrafficEndpoint'
user_id:
type: string
nullable: true
description: "UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated)."
example: "google-oauth2|123456789012345678901"
user_email:
type: string
nullable: true
description: "Email of the user who initiated the event (if any)."
example: "alice@netbird.io"
user_name:
type: string
nullable: true
description: "Name of the user who initiated the event (if any)."
example: "Alice Smith"
destination: destination:
$ref: '#/components/schemas/NetworkTrafficEndpoint' $ref: '#/components/schemas/NetworkTrafficEndpoint'
user:
$ref: '#/components/schemas/NetworkTrafficUser'
policy:
$ref: '#/components/schemas/NetworkTrafficPolicy'
icmp:
$ref: '#/components/schemas/NetworkTrafficICMP'
protocol: protocol:
type: integer type: integer
description: "Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.)." description: "Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.)."
example: 6 example: 6
type:
type: string
description: "Type of the event (e.g. TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP)."
example: "TYPE_START"
direction: direction:
type: string type: string
description: "Direction of the traffic (e.g. DIRECTION_UNKNOWN, INGRESS, EGRESS)." description: "Direction of the traffic (e.g. DIRECTION_UNKNOWN, INGRESS, EGRESS)."
@@ -1992,43 +2027,28 @@ components:
type: integer type: integer
description: "Number of packets transmitted." description: "Number of packets transmitted."
example: 5 example: 5
policy_id: events:
type: string type: array
description: "ID of the policy that allowed this event." description: "List of events that are correlated to this flow (e.g., start, end)."
example: "ch8i4ug6lnn4g9hqv7m0" items:
policy_name: $ref: '#/components/schemas/NetworkTrafficSubEvent'
type: string
description: "Name of the policy that allowed this event."
example: "All to All"
icmp_type:
type: integer
description: "ICMP type (if applicable)."
example: 8
icmp_code:
type: integer
description: "ICMP code (if applicable)."
example: 0
required: required:
- id - id
- flow_id - flow_id
- reporter_id - reporter_id
- timestamp
- receive_timestamp - receive_timestamp
- source - source
- user_id
- user_email
- destination - destination
- user
- policy
- icmp
- protocol - protocol
- type
- direction - direction
- rx_bytes - rx_bytes
- rx_packets - rx_packets
- tx_bytes - tx_bytes
- tx_packets - tx_packets
- policy_id - events
- policy_name
- icmp_type
- icmp_code
NetworkTrafficEventsResponse: NetworkTrafficEventsResponse:
type: object type: object
properties: properties:

View File

@@ -880,30 +880,17 @@ type NetworkTrafficEvent struct {
// Direction Direction of the traffic (e.g. DIRECTION_UNKNOWN, INGRESS, EGRESS). // Direction Direction of the traffic (e.g. DIRECTION_UNKNOWN, INGRESS, EGRESS).
Direction string `json:"direction"` Direction string `json:"direction"`
// Events List of events that are correlated to this flow (e.g., start, end).
Events []NetworkTrafficSubEvent `json:"events"`
// FlowId FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection). // FlowId FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection).
FlowId string `json:"flow_id"` FlowId string `json:"flow_id"`
Icmp NetworkTrafficICMP `json:"icmp"`
// IcmpCode ICMP code (if applicable). Policy NetworkTrafficPolicy `json:"policy"`
IcmpCode int `json:"icmp_code"`
// IcmpType ICMP type (if applicable).
IcmpType int `json:"icmp_type"`
// Id ID of the event. Unique.
Id string `json:"id"`
// PolicyId ID of the policy that allowed this event.
PolicyId string `json:"policy_id"`
// PolicyName Name of the policy that allowed this event.
PolicyName string `json:"policy_name"`
// Protocol Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.). // Protocol Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.).
Protocol int `json:"protocol"` Protocol int `json:"protocol"`
// ReceiveTimestamp Timestamp when the event was received by our API.
ReceiveTimestamp time.Time `json:"receive_timestamp"`
// ReporterId ID of the reporter of the event (e.g., the peer that reported the event). // ReporterId ID of the reporter of the event (e.g., the peer that reported the event).
ReporterId string `json:"reporter_id"` ReporterId string `json:"reporter_id"`
@@ -914,26 +901,12 @@ type NetworkTrafficEvent struct {
RxPackets int `json:"rx_packets"` RxPackets int `json:"rx_packets"`
Source NetworkTrafficEndpoint `json:"source"` Source NetworkTrafficEndpoint `json:"source"`
// Timestamp Timestamp of the event. Send by the peer.
Timestamp time.Time `json:"timestamp"`
// TxBytes Number of bytes transmitted. // TxBytes Number of bytes transmitted.
TxBytes int `json:"tx_bytes"` TxBytes int `json:"tx_bytes"`
// TxPackets Number of packets transmitted. // TxPackets Number of packets transmitted.
TxPackets int `json:"tx_packets"` TxPackets int `json:"tx_packets"`
User NetworkTrafficUser `json:"user"`
// Type Type of the event (e.g. TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP).
Type string `json:"type"`
// UserEmail Email of the user who initiated the event (if any).
UserEmail *string `json:"user_email"`
// UserId UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated).
UserId *string `json:"user_id"`
// UserName Name of the user who initiated the event (if any).
UserName *string `json:"user_name"`
} }
// NetworkTrafficEventsResponse defines model for NetworkTrafficEventsResponse. // NetworkTrafficEventsResponse defines model for NetworkTrafficEventsResponse.
@@ -954,6 +927,15 @@ type NetworkTrafficEventsResponse struct {
TotalRecords int `json:"total_records"` TotalRecords int `json:"total_records"`
} }
// NetworkTrafficICMP defines model for NetworkTrafficICMP.
type NetworkTrafficICMP struct {
// Code ICMP code (if applicable).
Code int `json:"code"`
// Type ICMP type (if applicable).
Type int `json:"type"`
}
// NetworkTrafficLocation defines model for NetworkTrafficLocation. // NetworkTrafficLocation defines model for NetworkTrafficLocation.
type NetworkTrafficLocation struct { type NetworkTrafficLocation struct {
// CityName Name of the city (if known). // CityName Name of the city (if known).
@@ -963,6 +945,36 @@ type NetworkTrafficLocation struct {
CountryCode string `json:"country_code"` CountryCode string `json:"country_code"`
} }
// NetworkTrafficPolicy defines model for NetworkTrafficPolicy.
type NetworkTrafficPolicy struct {
// Id ID of the policy that allowed this event.
Id string `json:"id"`
// Name Name of the policy that allowed this event.
Name string `json:"name"`
}
// NetworkTrafficSubEvent defines model for NetworkTrafficSubEvent.
type NetworkTrafficSubEvent struct {
// Timestamp Timestamp of the event as sent by the peer.
Timestamp time.Time `json:"timestamp"`
// Type Type of the event (e.g., TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP).
Type string `json:"type"`
}
// NetworkTrafficUser defines model for NetworkTrafficUser.
type NetworkTrafficUser struct {
// Email Email of the user who initiated the event (if any).
Email string `json:"email"`
// Id UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated).
Id string `json:"id"`
// Name Name of the user who initiated the event (if any).
Name string `json:"name"`
}
// OSVersionCheck Posture check for the version of operating system // OSVersionCheck Posture check for the version of operating system
type OSVersionCheck struct { type OSVersionCheck struct {
// Android Posture check for the version of operating system // Android Posture check for the version of operating system

View File

@@ -255,23 +255,12 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
} }
// validate policy object // validate policy object
switch pr.Protocol { if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP:
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return return
} }
if !pr.Bidirectional {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
} }
case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP:
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
}
policy.Rules = append(policy.Rules, &pr) policy.Rules = append(policy.Rules, &pr)
} }

View File

@@ -58,6 +58,11 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
IP: net.ParseIP("100.65.29.55"), IP: net.ParseIP("100.65.29.55"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
"peerI": {
ID: "peerI",
IP: net.ParseIP("100.65.31.2"),
Status: &nbpeer.PeerStatus{},
},
}, },
Groups: map[string]*types.Group{ Groups: map[string]*types.Group{
"GroupAll": { "GroupAll": {
@@ -99,6 +104,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerH", "peerH",
}, },
}, },
"GroupDMZ": {
ID: "GroupDMZ",
Name: "dmz",
Peers: []string{
"peerI",
},
},
}, },
Policies: []*types.Policy{ Policies: []*types.Policy{
{ {
@@ -148,6 +160,35 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
}, },
}, },
{
ID: "RuleDMZ",
Name: "Dmz",
Description: "No description",
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: "RuleDMZ",
Name: "Dmz",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
PortRanges: []types.RulePortRange{
{
Start: 8080,
End: 8083,
},
},
Sources: []string{
"GroupWorkstations",
},
Destinations: []string{
"GroupDMZ",
},
},
},
},
}, },
} }
@@ -166,7 +207,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
assert.Len(t, peers, 7) assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
@@ -174,8 +215,9 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerF"]) assert.Contains(t, peers, account.Peers["peerF"])
assert.Contains(t, peers, account.Peers["peerG"]) assert.Contains(t, peers, account.Peers["peerG"])
assert.Contains(t, peers, account.Peers["peerH"]) assert.Contains(t, peers, account.Peers["peerH"])
assert.Contains(t, peers, account.Peers["peerI"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -292,12 +334,28 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Port: "", Port: "",
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 8080, End: 8083},
PolicyID: "RuleDMZ",
},
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 8080, End: 8083},
PolicyID: "RuleDMZ",
},
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
for _, rule := range firewallRules { for _, rule := range firewallRules {
contains := false contains := false
for _, expectedRule := range epectedFirewallRules { for _, expectedRule := range expectedFirewallRules {
if rule.Equal(expectedRule) { if rule.Equal(expectedRule) {
contains = true contains = true
break break

View File

@@ -364,11 +364,14 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
return nil, nil, fmt.Errorf("failed to add all group to account: %v", err) return nil, nil, fmt.Errorf("failed to add all group to account: %v", err)
} }
var sqlStore Store
var cleanup func()
maxRetries := 2 maxRetries := 2
for i := 0; i < maxRetries; i++ { for i := 0; i < maxRetries; i++ {
sqlStore, cleanUp, err := getSqlStoreEngine(ctx, store, kind) sqlStore, cleanup, err = getSqlStoreEngine(ctx, store, kind)
if err == nil { if err == nil {
return sqlStore, cleanUp, nil return sqlStore, cleanup, nil
} }
if i < maxRetries-1 { if i < maxRetries-1 {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -426,16 +429,16 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
} }
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) { func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok || dsn == "" {
var err error var err error
_, err = testutil.CreatePostgresTestContainer() _, dsn, err = testutil.CreatePostgresTestContainer()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
dsn, ok := os.LookupEnv(postgresDsnEnv) if dsn == "" {
if !ok {
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
} }
@@ -446,28 +449,28 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
dsn, cleanup, err := createRandomDB(dsn, db, kind) dsn, cleanup, err := createRandomDB(dsn, db, kind)
if err != nil { if err != nil {
return nil, cleanup, err return nil, nil, err
} }
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil { if err != nil {
return nil, cleanup, err return nil, nil, err
} }
return store, cleanup, nil return store, cleanup, nil
} }
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) { func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { dsn, ok := os.LookupEnv(mysqlDsnEnv)
if !ok || dsn == "" {
var err error var err error
_, err = testutil.CreateMysqlTestContainer() _, dsn, err = testutil.CreateMysqlTestContainer()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
dsn, ok := os.LookupEnv(mysqlDsnEnv) if dsn == "" {
if !ok {
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
} }
@@ -478,7 +481,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
dsn, cleanup, err := createRandomDB(dsn, db, kind) dsn, cleanup, err := createRandomDB(dsn, db, kind)
if err != nil { if err != nil {
return nil, cleanup, err return nil, nil, err
} }
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)

View File

@@ -5,7 +5,6 @@ package testutil
import ( import (
"context" "context"
"os"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,11 +15,25 @@ import (
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
) )
var (
pgContainer *postgres.PostgresContainer
mysqlContainer *mysql.MySQLContainer
)
// CreateMysqlTestContainer creates a new MySQL container for testing. // CreateMysqlTestContainer creates a new MySQL container for testing.
func CreateMysqlTestContainer() (func(), error) { func CreateMysqlTestContainer() (func(), string, error) {
ctx := context.Background() ctx := context.Background()
myContainer, err := mysql.RunContainer(ctx, if mysqlContainer != nil {
connStr, err := mysqlContainer.ConnectionString(ctx)
if err != nil {
return nil, "", err
}
return noOpCleanup, connStr, nil
}
var err error
mysqlContainer, err = mysql.RunContainer(ctx,
testcontainers.WithImage("mlsmaycon/warmed-mysql:8"), testcontainers.WithImage("mlsmaycon/warmed-mysql:8"),
mysql.WithDatabase("testing"), mysql.WithDatabase("testing"),
mysql.WithUsername("root"), mysql.WithUsername("root"),
@@ -31,31 +44,39 @@ func CreateMysqlTestContainer() (func(), error) {
), ),
) )
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
cleanup := func() { cleanup := func() {
os.Unsetenv("NETBIRD_STORE_ENGINE_MYSQL_DSN")
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc() defer cancelFunc()
if err = myContainer.Terminate(timeoutCtx); err != nil { if err = mysqlContainer.Terminate(timeoutCtx); err != nil {
log.WithContext(ctx).Warnf("failed to stop mysql container %s: %s", myContainer.GetContainerID(), err) log.WithContext(ctx).Warnf("failed to stop mysql container %s: %s", mysqlContainer.GetContainerID(), err)
} }
} }
talksConn, err := myContainer.ConnectionString(ctx) talksConn, err := mysqlContainer.ConnectionString(ctx)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_MYSQL_DSN", talksConn) return cleanup, talksConn, nil
} }
// CreatePostgresTestContainer creates a new PostgreSQL container for testing. // CreatePostgresTestContainer creates a new PostgreSQL container for testing.
func CreatePostgresTestContainer() (func(), error) { func CreatePostgresTestContainer() (func(), string, error) {
ctx := context.Background() ctx := context.Background()
pgContainer, err := postgres.RunContainer(ctx, if pgContainer != nil {
connStr, err := pgContainer.ConnectionString(ctx)
if err != nil {
return nil, "", err
}
return noOpCleanup, connStr, nil
}
var err error
pgContainer, err = postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:16-alpine"), testcontainers.WithImage("postgres:16-alpine"),
postgres.WithDatabase("netbird"), postgres.WithDatabase("netbird"),
postgres.WithUsername("root"), postgres.WithUsername("root"),
@@ -66,11 +87,10 @@ func CreatePostgresTestContainer() (func(), error) {
), ),
) )
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
cleanup := func() { cleanup := func() {
os.Unsetenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN")
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc() defer cancelFunc()
if err = pgContainer.Terminate(timeoutCtx); err != nil { if err = pgContainer.Terminate(timeoutCtx); err != nil {
@@ -80,10 +100,14 @@ func CreatePostgresTestContainer() (func(), error) {
talksConn, err := pgContainer.ConnectionString(ctx) talksConn, err := pgContainer.ConnectionString(ctx)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) return cleanup, talksConn, nil
}
func noOpCleanup() {
// no-op
} }
// CreateRedisTestContainer creates a new Redis container for testing. // CreateRedisTestContainer creates a new Redis container for testing.

View File

@@ -3,16 +3,16 @@
package testutil package testutil
func CreatePostgresTestContainer() (func(), error) { func CreatePostgresTestContainer() (func(), string, error) {
return func() { return func() {
// Empty function for Postgres // Empty function for Postgres
}, nil }, "", nil
} }
func CreateMysqlTestContainer() (func(), error) { func CreateMysqlTestContainer() (func(), string, error) {
return func() { return func() {
// Empty function for MySQL // Empty function for MySQL
}, nil }, "", nil
} }
func CreateRedisTestContainer() (func(), string, error) { func CreateRedisTestContainer() (func(), string, error) {

View File

@@ -1046,7 +1046,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
} }
rulesExists[ruleID] = struct{}{} rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 { if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
rules = append(rules, &fr) rules = append(rules, &fr)
continue continue
} }
@@ -1056,6 +1056,12 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
pr.Port = port pr.Port = port
rules = append(rules, &pr) rules = append(rules, &pr)
} }
for _, portRange := range rule.PortRanges {
pr := fr
pr.PortRange = portRange
rules = append(rules, &pr)
}
} }
}, func() ([]*nbpeer.Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules return peers, rules