Compare commits

...

9 Commits

Author SHA1 Message Date
Zoltan Papp
aa07b3b87b Fix deadlock (#3904) 2025-05-30 23:38:02 +02:00
Bethuel Mmbaga
2bef214cc0 [management] Fix user groups propagation (#3902) 2025-05-30 18:12:30 +03:00
hakansa
cfb2d82352 [client] Refactor exclude list handling to use a map for permanent connections (#3901)
[client] Refactor exclude list handling to use a map for permanent connections (#3901)
2025-05-30 16:54:49 +03:00
Bethuel Mmbaga
684501fd35 [management] Prevent deletion of peers linked to network routers (#3881)
- Prevent deletion of peers linked to network routers
- Add API endpoint to list all network routers
2025-05-29 18:50:00 +03:00
Zoltan Papp
0492c1724a [client, android] Fix/notifier threading (#3807)
- Fix potential deadlocks
- When adding a listener, immediately notify with the last known IP and fqdn.
2025-05-27 17:12:04 +02:00
Zoltan Papp
6f436e57b5 [server-test] Install libs for i386 tests (#3887)
Install libs for i386 tests
2025-05-27 16:42:06 +02:00
Bethuel Mmbaga
a0d28f9851 [management] Reset test containers after cleanup (#3885) 2025-05-27 14:42:00 +03:00
Zoltan Papp
cdd27a9fe5 [client, android] Fix/android enable server route (#3806)
Enable the server route; otherwise, the manager throws an error and the engine will restart.
2025-05-27 13:32:54 +02:00
Bethuel Mmbaga
5523040acd [management] Add correlated network traffic event schema (#3680) 2025-05-27 13:47:53 +03:00
15 changed files with 357 additions and 209 deletions

View File

@@ -223,6 +223,10 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -269,6 +273,10 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -98,14 +98,14 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs []string) {
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for _, peerID := range peerIDs {
for peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {

View File

@@ -1927,14 +1927,16 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
return forwardingRules, nberrors.FormatErrorOrNil(merr)
}
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) []string {
excludedPeers := make([]string, 0)
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
excludedPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer == "" {
continue
}
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers = append(excludedPeers, r.Peer)
if !excludedPeers[r.Peer] {
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers[r.Peer] = true
}
}
for _, r := range rules {
@@ -1945,7 +1947,7 @@ func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallMana
continue
}
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
excludedPeers = append(excludedPeers, p.GetWgPubKey())
excludedPeers[p.GetWgPubKey()] = true
}
}
}

View File

@@ -691,8 +691,7 @@ func (conn *Conn) evalStatus() ConnStatus {
}
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
conn.mu.Lock()
defer conn.mu.Unlock()
// would be better to protect this with a mutex, but it could cause deadlock with Close function
defer func() {
if !connected {

View File

@@ -18,6 +18,8 @@ type notifier struct {
currentClientState bool
lastNotification int
lastNumberOfPeers int
lastFqdnAddress string
lastIPAddress string
}
func newNotifier() *notifier {
@@ -25,15 +27,22 @@ func newNotifier() *notifier {
}
func (n *notifier) setListener(listener Listener) {
n.serverStateLock.Lock()
lastNotification := n.lastNotification
numOfPeers := n.lastNumberOfPeers
fqdnAddress := n.lastFqdnAddress
address := n.lastIPAddress
n.serverStateLock.Unlock()
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
n.serverStateLock.Lock()
n.notifyListener(listener, n.lastNotification)
listener.OnPeersListChanged(n.lastNumberOfPeers)
n.serverStateLock.Unlock()
n.listener = listener
listener.OnAddressChanged(fqdnAddress, address)
notifyListener(listener, lastNotification)
// run on go routine to avoid on Java layer to call go functions on same thread
go listener.OnPeersListChanged(numOfPeers)
}
func (n *notifier) removeListener() {
@@ -44,41 +53,44 @@ func (n *notifier) removeListener() {
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
calculatedState := n.calculateState(mgmState, signalState)
if !n.isServerStateChanged(calculatedState) {
n.serverStateLock.Unlock()
return
}
n.lastNotification = calculatedState
n.serverStateLock.Unlock()
n.notify(n.lastNotification)
n.notify(calculatedState)
}
func (n *notifier) clientStart() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = true
n.lastNotification = stateConnecting
n.notify(n.lastNotification)
n.serverStateLock.Unlock()
n.notify(stateConnecting)
}
func (n *notifier) clientStop() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = false
n.lastNotification = stateDisconnected
n.notify(n.lastNotification)
n.serverStateLock.Unlock()
n.notify(stateDisconnected)
}
func (n *notifier) clientTearDown() {
n.serverStateLock.Lock()
defer n.serverStateLock.Unlock()
n.currentClientState = false
n.lastNotification = stateDisconnecting
n.notify(n.lastNotification)
n.serverStateLock.Unlock()
n.notify(stateDisconnecting)
}
func (n *notifier) isServerStateChanged(newState int) bool {
@@ -87,26 +99,14 @@ func (n *notifier) isServerStateChanged(newState int) bool {
func (n *notifier) notify(state int) {
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
listener := n.listener
n.listenersLock.Unlock()
if listener == nil {
return
}
n.notifyListener(n.listener, state)
}
func (n *notifier) notifyListener(l Listener, state int) {
go func() {
switch state {
case stateDisconnected:
l.OnDisconnected()
case stateConnected:
l.OnConnected()
case stateConnecting:
l.OnConnecting()
case stateDisconnecting:
l.OnDisconnecting()
}
}()
notifyListener(listener, state)
}
func (n *notifier) calculateState(managementConn, signalConn bool) int {
@@ -126,20 +126,48 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
}
func (n *notifier) peerListChanged(numOfPeers int) {
n.serverStateLock.Lock()
n.lastNumberOfPeers = numOfPeers
n.serverStateLock.Unlock()
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
listener := n.listener
n.listenersLock.Unlock()
if listener == nil {
return
}
n.listener.OnPeersListChanged(numOfPeers)
// run on go routine to avoid on Java layer to call go functions on same thread
go listener.OnPeersListChanged(numOfPeers)
}
func (n *notifier) localAddressChanged(fqdn, address string) {
n.serverStateLock.Lock()
n.lastFqdnAddress = fqdn
n.lastIPAddress = address
n.serverStateLock.Unlock()
n.listenersLock.Lock()
defer n.listenersLock.Unlock()
if n.listener == nil {
listener := n.listener
n.listenersLock.Unlock()
if listener == nil {
return
}
n.listener.OnAddressChanged(fqdn, address)
listener.OnAddressChanged(fqdn, address)
}
func notifyListener(l Listener, state int) {
switch state {
case stateDisconnected:
l.OnDisconnected()
case stateConnected:
l.OnConnected()
case stateConnecting:
l.OnConnecting()
case stateDisconnecting:
l.OnDisconnecting()
}
}

View File

@@ -1,5 +1,3 @@
//go:build !android
package routemanager
import (

View File

@@ -1,27 +0,0 @@
//go:build android
package routemanager
import (
"context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
}
func (r serverRouter) cleanUp() {
}
func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error {
return nil
}
func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -1925,13 +1925,71 @@ components:
- os
- address
- dns_label
NetworkTrafficEvent:
NetworkTrafficUser:
type: object
properties:
id:
type: string
description: "ID of the event. Unique."
example: "18e204d6-f7c6-405d-8025-70becb216add"
description: "UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated)."
example: "google-oauth2|123456789012345678901"
email:
type: string
description: "Email of the user who initiated the event (if any)."
example: "alice@netbird.io"
name:
type: string
description: "Name of the user who initiated the event (if any)."
example: "Alice Smith"
required:
- id
- email
- name
NetworkTrafficPolicy:
type: object
properties:
id:
type: string
description: "ID of the policy that allowed this event."
example: "ch8i4ug6lnn4g9hqv7m0"
name:
type: string
description: "Name of the policy that allowed this event."
example: "All to All"
required:
- id
- name
NetworkTrafficICMP:
type: object
properties:
type:
type: integer
description: "ICMP type (if applicable)."
example: 8
code:
type: integer
description: "ICMP code (if applicable)."
example: 0
required:
- type
- code
NetworkTrafficSubEvent:
type: object
properties:
type:
type: string
description: Type of the event (e.g., TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP).
example: TYPE_START
timestamp:
type: string
format: date-time
description: Timestamp of the event as sent by the peer.
example: 2025-03-20T16:23:58.125397Z
required:
- type
- timestamp
NetworkTrafficEvent:
type: object
properties:
flow_id:
type: string
description: "FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection)."
@@ -1940,43 +1998,20 @@ components:
type: string
description: "ID of the reporter of the event (e.g., the peer that reported the event)."
example: "ch8i4ug6lnn4g9hqv7m0"
timestamp:
type: string
format: date-time
description: "Timestamp of the event. Send by the peer."
example: "2025-03-20T16:23:58.125397Z"
receive_timestamp:
type: string
format: date-time
description: "Timestamp when the event was received by our API."
example: "2025-03-20T16:23:58.125397Z"
source:
$ref: '#/components/schemas/NetworkTrafficEndpoint'
user_id:
type: string
nullable: true
description: "UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated)."
example: "google-oauth2|123456789012345678901"
user_email:
type: string
nullable: true
description: "Email of the user who initiated the event (if any)."
example: "alice@netbird.io"
user_name:
type: string
nullable: true
description: "Name of the user who initiated the event (if any)."
example: "Alice Smith"
destination:
$ref: '#/components/schemas/NetworkTrafficEndpoint'
user:
$ref: '#/components/schemas/NetworkTrafficUser'
policy:
$ref: '#/components/schemas/NetworkTrafficPolicy'
icmp:
$ref: '#/components/schemas/NetworkTrafficICMP'
protocol:
type: integer
description: "Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.)."
example: 6
type:
type: string
description: "Type of the event (e.g. TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP)."
example: "TYPE_START"
direction:
type: string
description: "Direction of the traffic (e.g. DIRECTION_UNKNOWN, INGRESS, EGRESS)."
@@ -1997,43 +2032,28 @@ components:
type: integer
description: "Number of packets transmitted."
example: 5
policy_id:
type: string
description: "ID of the policy that allowed this event."
example: "ch8i4ug6lnn4g9hqv7m0"
policy_name:
type: string
description: "Name of the policy that allowed this event."
example: "All to All"
icmp_type:
type: integer
description: "ICMP type (if applicable)."
example: 8
icmp_code:
type: integer
description: "ICMP code (if applicable)."
example: 0
events:
type: array
description: "List of events that are correlated to this flow (e.g., start, end)."
items:
$ref: '#/components/schemas/NetworkTrafficSubEvent'
required:
- id
- flow_id
- reporter_id
- timestamp
- receive_timestamp
- source
- user_id
- user_email
- destination
- user
- policy
- icmp
- protocol
- type
- direction
- rx_bytes
- rx_packets
- tx_bytes
- tx_packets
- policy_id
- policy_name
- icmp_type
- icmp_code
- events
NetworkTrafficEventsResponse:
type: object
properties:
@@ -4048,6 +4068,31 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks/routers:
get:
summary: List all Network Routers
description: Returns a list of all routers in a network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Routers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/NetworkRouter'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers:
get:
summary: List all Nameserver Groups

View File

@@ -883,30 +883,17 @@ type NetworkTrafficEvent struct {
// Direction Direction of the traffic (e.g. DIRECTION_UNKNOWN, INGRESS, EGRESS).
Direction string `json:"direction"`
// Events List of events that are correlated to this flow (e.g., start, end).
Events []NetworkTrafficSubEvent `json:"events"`
// FlowId FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection).
FlowId string `json:"flow_id"`
// IcmpCode ICMP code (if applicable).
IcmpCode int `json:"icmp_code"`
// IcmpType ICMP type (if applicable).
IcmpType int `json:"icmp_type"`
// Id ID of the event. Unique.
Id string `json:"id"`
// PolicyId ID of the policy that allowed this event.
PolicyId string `json:"policy_id"`
// PolicyName Name of the policy that allowed this event.
PolicyName string `json:"policy_name"`
FlowId string `json:"flow_id"`
Icmp NetworkTrafficICMP `json:"icmp"`
Policy NetworkTrafficPolicy `json:"policy"`
// Protocol Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.).
Protocol int `json:"protocol"`
// ReceiveTimestamp Timestamp when the event was received by our API.
ReceiveTimestamp time.Time `json:"receive_timestamp"`
// ReporterId ID of the reporter of the event (e.g., the peer that reported the event).
ReporterId string `json:"reporter_id"`
@@ -917,26 +904,12 @@ type NetworkTrafficEvent struct {
RxPackets int `json:"rx_packets"`
Source NetworkTrafficEndpoint `json:"source"`
// Timestamp Timestamp of the event. Send by the peer.
Timestamp time.Time `json:"timestamp"`
// TxBytes Number of bytes transmitted.
TxBytes int `json:"tx_bytes"`
// TxPackets Number of packets transmitted.
TxPackets int `json:"tx_packets"`
// Type Type of the event (e.g. TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP).
Type string `json:"type"`
// UserEmail Email of the user who initiated the event (if any).
UserEmail *string `json:"user_email"`
// UserId UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated).
UserId *string `json:"user_id"`
// UserName Name of the user who initiated the event (if any).
UserName *string `json:"user_name"`
TxPackets int `json:"tx_packets"`
User NetworkTrafficUser `json:"user"`
}
// NetworkTrafficEventsResponse defines model for NetworkTrafficEventsResponse.
@@ -957,6 +930,15 @@ type NetworkTrafficEventsResponse struct {
TotalRecords int `json:"total_records"`
}
// NetworkTrafficICMP defines model for NetworkTrafficICMP.
type NetworkTrafficICMP struct {
// Code ICMP code (if applicable).
Code int `json:"code"`
// Type ICMP type (if applicable).
Type int `json:"type"`
}
// NetworkTrafficLocation defines model for NetworkTrafficLocation.
type NetworkTrafficLocation struct {
// CityName Name of the city (if known).
@@ -966,6 +948,36 @@ type NetworkTrafficLocation struct {
CountryCode string `json:"country_code"`
}
// NetworkTrafficPolicy defines model for NetworkTrafficPolicy.
type NetworkTrafficPolicy struct {
// Id ID of the policy that allowed this event.
Id string `json:"id"`
// Name Name of the policy that allowed this event.
Name string `json:"name"`
}
// NetworkTrafficSubEvent defines model for NetworkTrafficSubEvent.
type NetworkTrafficSubEvent struct {
// Timestamp Timestamp of the event as sent by the peer.
Timestamp time.Time `json:"timestamp"`
// Type Type of the event (e.g., TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP).
Type string `json:"type"`
}
// NetworkTrafficUser defines model for NetworkTrafficUser.
type NetworkTrafficUser struct {
// Email Email of the user who initiated the event (if any).
Email string `json:"email"`
// Id UserID is the ID of the user that initiated the event (can be empty as not every event is user-initiated).
Id string `json:"id"`
// Name Name of the user who initiated the event (if any).
Name string `json:"name"`
}
// OSVersionCheck Posture check for the version of operating system
type OSVersionCheck struct {
// Android Posture check for the version of operating system

View File

@@ -19,7 +19,8 @@ type routersHandler struct {
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
@@ -41,6 +42,31 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
accountID, userID := userAuth.AccountId, userAuth.UserId
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routersResponse := make([]*api.NetworkRouter, 0)
for _, routers := range routersMap {
for _, router := range routers {
routersResponse = append(routersResponse, router.ToAPIResponse())
}
}
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/geolocation"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -352,7 +353,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if err = am.validatePeerDelete(ctx, accountID, peerID); err != nil {
if err = am.validatePeerDelete(ctx, transaction, accountID, peerID); err != nil {
return err
}
@@ -1543,7 +1544,7 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
}
// validatePeerDelete checks if the peer can be deleted.
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, accountId, peerId string) error {
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error {
linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId)
if err != nil {
return err
@@ -1553,5 +1554,27 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, account
return status.Errorf(status.PreconditionFailed, "peer is linked to ingress ports: %s", peerId)
}
linked, router := isPeerLinkedToNetworkRouter(ctx, transaction, accountId, peerId)
if linked {
return status.Errorf(status.PreconditionFailed, "peer is linked to a network router: %s", router.ID)
}
return nil
}
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil
}
for _, router := range routers {
if router.Peer == peerID {
return true, router
}
}
return false, nil
}

View File

@@ -365,11 +365,14 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
return nil, nil, fmt.Errorf("failed to add all group to account: %v", err)
}
var sqlStore Store
var cleanup func()
maxRetries := 2
for i := 0; i < maxRetries; i++ {
sqlStore, cleanUp, err := getSqlStoreEngine(ctx, store, kind)
sqlStore, cleanup, err = getSqlStoreEngine(ctx, store, kind)
if err == nil {
return sqlStore, cleanUp, nil
return sqlStore, cleanup, nil
}
if i < maxRetries-1 {
time.Sleep(100 * time.Millisecond)
@@ -427,16 +430,16 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok || dsn == "" {
var err error
_, err = testutil.CreatePostgresTestContainer()
_, dsn, err = testutil.CreatePostgresTestContainer()
if err != nil {
return nil, nil, err
}
}
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
if dsn == "" {
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
@@ -447,28 +450,28 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
dsn, cleanup, err := createRandomDB(dsn, db, kind)
if err != nil {
return nil, cleanup, err
return nil, nil, err
}
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, cleanup, err
return nil, nil, err
}
return store, cleanup, nil
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
if !ok || dsn == "" {
var err error
_, err = testutil.CreateMysqlTestContainer()
_, dsn, err = testutil.CreateMysqlTestContainer()
if err != nil {
return nil, nil, err
}
}
dsn, ok := os.LookupEnv(mysqlDsnEnv)
if !ok {
if dsn == "" {
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
@@ -479,7 +482,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
dsn, cleanup, err := createRandomDB(dsn, db, kind)
if err != nil {
return nil, cleanup, err
return nil, nil, err
}
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)

View File

@@ -5,7 +5,6 @@ package testutil
import (
"context"
"os"
"time"
log "github.com/sirupsen/logrus"
@@ -16,11 +15,25 @@ import (
"github.com/testcontainers/testcontainers-go/wait"
)
var (
pgContainer *postgres.PostgresContainer
mysqlContainer *mysql.MySQLContainer
)
// CreateMysqlTestContainer creates a new MySQL container for testing.
func CreateMysqlTestContainer() (func(), error) {
func CreateMysqlTestContainer() (func(), string, error) {
ctx := context.Background()
myContainer, err := mysql.RunContainer(ctx,
if mysqlContainer != nil {
connStr, err := mysqlContainer.ConnectionString(ctx)
if err != nil {
return nil, "", err
}
return noOpCleanup, connStr, nil
}
var err error
mysqlContainer, err = mysql.RunContainer(ctx,
testcontainers.WithImage("mlsmaycon/warmed-mysql:8"),
mysql.WithDatabase("testing"),
mysql.WithUsername("root"),
@@ -31,31 +44,42 @@ func CreateMysqlTestContainer() (func(), error) {
),
)
if err != nil {
return nil, err
return nil, "", err
}
cleanup := func() {
os.Unsetenv("NETBIRD_STORE_ENGINE_MYSQL_DSN")
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = myContainer.Terminate(timeoutCtx); err != nil {
log.WithContext(ctx).Warnf("failed to stop mysql container %s: %s", myContainer.GetContainerID(), err)
if mysqlContainer != nil {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = mysqlContainer.Terminate(timeoutCtx); err != nil {
log.WithContext(ctx).Warnf("failed to stop mysql container %s: %s", mysqlContainer.GetContainerID(), err)
}
mysqlContainer = nil // reset the container to allow recreation
}
}
talksConn, err := myContainer.ConnectionString(ctx)
talksConn, err := mysqlContainer.ConnectionString(ctx)
if err != nil {
return nil, err
return nil, "", err
}
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_MYSQL_DSN", talksConn)
return cleanup, talksConn, nil
}
// CreatePostgresTestContainer creates a new PostgreSQL container for testing.
func CreatePostgresTestContainer() (func(), error) {
func CreatePostgresTestContainer() (func(), string, error) {
ctx := context.Background()
pgContainer, err := postgres.RunContainer(ctx,
if pgContainer != nil {
connStr, err := pgContainer.ConnectionString(ctx)
if err != nil {
return nil, "", err
}
return noOpCleanup, connStr, nil
}
var err error
pgContainer, err = postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:16-alpine"),
postgres.WithDatabase("netbird"),
postgres.WithUsername("root"),
@@ -66,24 +90,31 @@ func CreatePostgresTestContainer() (func(), error) {
),
)
if err != nil {
return nil, err
return nil, "", err
}
cleanup := func() {
os.Unsetenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN")
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = pgContainer.Terminate(timeoutCtx); err != nil {
log.WithContext(ctx).Warnf("failed to stop postgres container %s: %s", pgContainer.GetContainerID(), err)
if pgContainer != nil {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = pgContainer.Terminate(timeoutCtx); err != nil {
log.WithContext(ctx).Warnf("failed to stop postgres container %s: %s", pgContainer.GetContainerID(), err)
}
pgContainer = nil // reset the container to allow recreation
}
}
talksConn, err := pgContainer.ConnectionString(ctx)
if err != nil {
return nil, err
return nil, "", err
}
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
return cleanup, talksConn, nil
}
func noOpCleanup() {
// no-op
}
// CreateRedisTestContainer creates a new Redis container for testing.

View File

@@ -3,16 +3,16 @@
package testutil
func CreatePostgresTestContainer() (func(), error) {
func CreatePostgresTestContainer() (func(), string, error) {
return func() {
// Empty function for Postgres
}, nil
}, "", nil
}
func CreateMysqlTestContainer() (func(), error) {
func CreateMysqlTestContainer() (func(), string, error) {
return func() {
// Empty function for MySQL
}, nil
}, "", nil
}
func CreateRedisTestContainer() (func(), string, error) {

View File

@@ -676,7 +676,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, update.AccountID, updatedGroups); err != nil {
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, updatedGroups); err != nil {
return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err)
}
}