From 3b7b9d25bcb2861e85118d564bcf6e71f8ed3e72 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 23 Apr 2025 01:07:04 +0200 Subject: [PATCH 01/45] [client] Keep new routes selected unless all are deselected (#3692) --- .../internal/routeselector/routeselector.go | 125 +++------- .../routeselector/routeselector_test.go | 231 +++++++++++++++++- 2 files changed, 260 insertions(+), 96 deletions(-) diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 72c4758f4..8ebdc63e5 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -14,23 +14,15 @@ import ( ) type RouteSelector struct { - mu sync.RWMutex - selectedRoutes map[route.NetID]struct{} - selectAll bool - - // Indicates if new routes should be automatically selected - includeNewRoutes bool - - // All known routes at the time of deselection - knownRoutes []route.NetID + mu sync.RWMutex + deselectedRoutes map[route.NetID]struct{} + deselectAll bool } func NewRouteSelector() *RouteSelector { return &RouteSelector{ - selectedRoutes: map[route.NetID]struct{}{}, - selectAll: true, - includeNewRoutes: false, - knownRoutes: []route.NetID{}, + deselectedRoutes: map[route.NetID]struct{}{}, + deselectAll: false, } } @@ -39,8 +31,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al rs.mu.Lock() defer rs.mu.Unlock() - if !appendRoute { - rs.selectedRoutes = map[route.NetID]struct{}{} + if !appendRoute || rs.deselectAll { + maps.Clear(rs.deselectedRoutes) + for _, r := range allRoutes { + rs.deselectedRoutes[r] = struct{}{} + } } var err *multierror.Error @@ -49,11 +44,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } - - rs.selectedRoutes[route] = struct{}{} + delete(rs.deselectedRoutes, route) } - rs.selectAll = false - rs.includeNewRoutes = false + + rs.deselectAll = false return errors.FormatErrorOrNil(err) } @@ -63,38 +57,26 @@ func (rs *RouteSelector) SelectAllRoutes() { rs.mu.Lock() defer rs.mu.Unlock() - rs.selectAll = true - rs.selectedRoutes = map[route.NetID]struct{}{} - rs.includeNewRoutes = false + rs.deselectAll = false + maps.Clear(rs.deselectedRoutes) } // DeselectRoutes removes specific routes from the selection. -// If the selector is in "select all" mode, it will transition to "select specific" mode -// but will keep new routes selected. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { rs.mu.Lock() defer rs.mu.Unlock() - if rs.selectAll { - rs.selectAll = false - rs.includeNewRoutes = true - rs.knownRoutes = make([]route.NetID, len(allRoutes)) - copy(rs.knownRoutes, allRoutes) - - rs.selectedRoutes = map[route.NetID]struct{}{} - for _, route := range allRoutes { - rs.selectedRoutes[route] = struct{}{} - } + if rs.deselectAll { + return nil } var err *multierror.Error - for _, route := range routes { if !slices.Contains(allRoutes, route) { err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } - delete(rs.selectedRoutes, route) + rs.deselectedRoutes[route] = struct{}{} } return errors.FormatErrorOrNil(err) @@ -105,9 +87,8 @@ func (rs *RouteSelector) DeselectAllRoutes() { rs.mu.Lock() defer rs.mu.Unlock() - rs.selectAll = false - rs.includeNewRoutes = false - rs.selectedRoutes = map[route.NetID]struct{}{} + rs.deselectAll = true + maps.Clear(rs.deselectedRoutes) } // IsSelected checks if a specific route is selected. @@ -115,23 +96,12 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.selectAll { - return true + if rs.deselectAll { + return false } - // Check if the route exists in selectedRoutes - _, selected := rs.selectedRoutes[routeID] - if selected { - return true - } - - // If includeNewRoutes is true and this is a new route (not in knownRoutes), - // then it should be selected - if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) { - return true - } - - return false + _, deselected := rs.deselectedRoutes[routeID] + return !deselected } // FilterSelected removes unselected routes from the provided map. @@ -139,17 +109,15 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.selectAll { - return maps.Clone(routes) + if rs.deselectAll { + return route.HAMap{} } filtered := route.HAMap{} for id, rt := range routes { netID := id.NetID() - _, selected := rs.selectedRoutes[netID] - - // Include if directly selected or if it's a new route and includeNewRoutes is true - if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) { + _, deselected := rs.deselectedRoutes[netID] + if !deselected { filtered[id] = rt } } @@ -162,15 +130,11 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) { defer rs.mu.RUnlock() return json.Marshal(struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` - IncludeNewRoutes bool `json:"include_new_routes"` - KnownRoutes []route.NetID `json:"known_routes"` + DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` + DeselectAll bool `json:"deselect_all"` }{ - SelectAll: rs.selectAll, - SelectedRoutes: rs.selectedRoutes, - IncludeNewRoutes: rs.includeNewRoutes, - KnownRoutes: rs.knownRoutes, + DeselectedRoutes: rs.deselectedRoutes, + DeselectAll: rs.deselectAll, }) } @@ -182,34 +146,25 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { // Check for null or empty JSON if len(data) == 0 || string(data) == "null" { - rs.selectedRoutes = map[route.NetID]struct{}{} - rs.selectAll = true - rs.includeNewRoutes = false - rs.knownRoutes = []route.NetID{} + rs.deselectedRoutes = map[route.NetID]struct{}{} + rs.deselectAll = false return nil } var temp struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` - IncludeNewRoutes bool `json:"include_new_routes"` - KnownRoutes []route.NetID `json:"known_routes"` + DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` + DeselectAll bool `json:"deselect_all"` } if err := json.Unmarshal(data, &temp); err != nil { return err } - rs.selectedRoutes = temp.SelectedRoutes - rs.selectAll = temp.SelectAll - rs.includeNewRoutes = temp.IncludeNewRoutes - rs.knownRoutes = temp.KnownRoutes + rs.deselectedRoutes = temp.DeselectedRoutes + rs.deselectAll = temp.DeselectAll - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } - if rs.knownRoutes == nil { - rs.knownRoutes = []route.NetID{} + if rs.deselectedRoutes == nil { + rs.deselectedRoutes = map[route.NetID]struct{}{} } return nil diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index a1461dff6..cfa723246 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -66,12 +66,10 @@ func TestRouteSelector_SelectRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { rs := routeselector.NewRouteSelector() - if tt.initialSelected != nil { - err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) - require.NoError(t, err) - } + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) - err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) + err = rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) if tt.wantError { assert.Error(t, err) } else { @@ -251,7 +249,8 @@ func TestRouteSelector_IsSelected(t *testing.T) { assert.True(t, rs.IsSelected("route1")) assert.True(t, rs.IsSelected("route2")) assert.False(t, rs.IsSelected("route3")) - assert.False(t, rs.IsSelected("route4")) + // Unknown route is selected by default + assert.True(t, rs.IsSelected("route4")) } func TestRouteSelector_FilterSelected(t *testing.T) { @@ -297,8 +296,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialState: func(rs *routeselector.RouteSelector) error { return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) }, - // When specific routes were selected, new routes should remain unselected - wantNewSelected: []route.NetID{"route1", "route2"}, + // When specific routes were selected, new routes should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route4", "route5"}, }, { name: "New routes after deselect all", @@ -315,7 +314,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { rs.SelectAllRoutes() return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) }, - // After deselecting specific routes, new routes should remain unselected + // After deselecting specific routes, new routes should be selected wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"}, }, { @@ -323,8 +322,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialState: func(rs *routeselector.RouteSelector) error { return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) }, - // When routes were appended, new routes should remain unselected - wantNewSelected: []route.NetID{"route1"}, + // When routes were appended, new routes should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, }, } @@ -428,3 +427,213 @@ func TestRouteSelector_MixedSelectionDeselection(t *testing.T) { }) } } + +func TestRouteSelector_AfterDeselectAll(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3"} + + tests := []struct { + name string + initialAction func(rs *routeselector.RouteSelector) error + secondAction func(rs *routeselector.RouteSelector) error + wantSelected []route.NetID + wantError bool + }{ + { + name: "Deselect all -> select specific routes", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + wantSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "Deselect all -> select with append", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes) + }, + wantSelected: []route.NetID{"route1"}, + }, + { + name: "Deselect all -> deselect specific", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route1"}, allRoutes) + }, + wantSelected: []route.NetID{}, + }, + { + name: "Deselect all -> select all", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + wantSelected: []route.NetID{"route1", "route2", "route3"}, + }, + { + name: "Deselect all -> deselect non-existent route", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route4"}, allRoutes) + }, + wantSelected: []route.NetID{}, + wantError: false, + }, + { + name: "Select specific -> deselect all -> select different", + initialAction: func(rs *routeselector.RouteSelector) error { + err := rs.SelectRoutes([]route.NetID{"route1"}, false, allRoutes) + if err != nil { + return err + } + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route2", "route3"}, false, allRoutes) + }, + wantSelected: []route.NetID{"route2", "route3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + err := tt.initialAction(rs) + require.NoError(t, err) + + err = tt.secondAction(rs) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + for _, id := range allRoutes { + expected := slices.Contains(tt.wantSelected, id) + assert.Equal(t, expected, rs.IsSelected(id), + "Route %s selection state incorrect, expected %v", id, expected) + } + + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + } + + filtered := rs.FilterSelected(routes) + assert.Equal(t, len(tt.wantSelected), len(filtered), + "FilterSelected returned wrong number of routes") + }) + } +} + +func TestRouteSelector_ComplexScenarios(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3", "route4"} + + tests := []struct { + name string + actions []func(rs *routeselector.RouteSelector) error + wantSelected []route.NetID + }{ + { + name: "Select all -> deselect specific -> select different with append", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route1", "route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route3", "route4"}, + }, + { + name: "Deselect all -> select specific -> deselect one -> select different with append", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route3"}, true, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route3"}, + }, + { + name: "Select specific -> deselect specific -> select all -> deselect different", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route3", "route4"}, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + for i, action := range tt.actions { + err := action(rs) + require.NoError(t, err, "Action %d failed", i) + } + + for _, id := range allRoutes { + expected := slices.Contains(tt.wantSelected, id) + assert.Equal(t, expected, rs.IsSelected(id), + "Route %s selection state incorrect", id) + } + + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + } + + filtered := rs.FilterSelected(routes) + assert.Equal(t, len(tt.wantSelected), len(filtered), + "FilterSelected returned wrong number of routes") + }) + } +} From f74ea64c7b6949ece3c54201becbeed54917d3c8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:20:51 +0200 Subject: [PATCH 02/45] Bump golang.org/x/net from 0.36.0 to 0.38.0 (#3695) Bumps [golang.org/x/net](https://github.com/golang/net) from 0.36.0 to 0.38.0. - [Commits](https://github.com/golang/net/compare/v0.36.0...v0.38.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-version: 0.38.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index c00f32063..cb4b9850a 100644 --- a/go.mod +++ b/go.mod @@ -100,7 +100,7 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.36.0 + golang.org/x/net v0.38.0 golang.org/x/oauth2 v0.24.0 golang.org/x/sync v0.12.0 golang.org/x/term v0.30.0 diff --git a/go.sum b/go.sum index f00b42beb..f97f7527d 100644 --- a/go.sum +++ b/go.sum @@ -846,8 +846,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= -golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= From 197761ba4dfcac723b0818a235643639862cf6aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:21:36 +0200 Subject: [PATCH 03/45] Bump github.com/redis/go-redis/v9 from 9.7.1 to 9.7.3 (#3553) Bumps [github.com/redis/go-redis/v9](https://github.com/redis/go-redis) from 9.7.1 to 9.7.3. - [Release notes](https://github.com/redis/go-redis/releases) - [Changelog](https://github.com/redis/go-redis/blob/master/CHANGELOG.md) - [Commits](https://github.com/redis/go-redis/compare/v9.7.1...v9.7.3) --- updated-dependencies: - dependency-name: github.com/redis/go-redis/v9 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cb4b9850a..b1b01d446 100644 --- a/go.mod +++ b/go.mod @@ -75,7 +75,7 @@ require ( github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.48.2 - github.com/redis/go-redis/v9 v9.7.1 + github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 diff --git a/go.sum b/go.sum index f97f7527d..fb351dd25 100644 --- a/go.sum +++ b/go.sum @@ -576,8 +576,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= -github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc= -github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= +github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= +github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= From 986eb8c1e01f263d477c1841465c89252868c8c1 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:54:49 +0200 Subject: [PATCH 04/45] [management] fix lastLogin on dashboard (#3725) --- management/server/http/handler.go | 1 + .../server/http/middleware/auth_middleware.go | 24 ++++++++++++++----- .../http/middleware/auth_middleware_test.go | 6 +++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 483bb989a..3d4de31d0 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -62,6 +62,7 @@ func NewAPIHandler( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, + accountManager.GetUserFromUserAuth, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index a8e6790a9..6f0d1556f 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -15,16 +15,20 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager - ensureAccount EnsureAccountFunc - syncUserJWTGroups SyncUserJWTGroupsFunc + authManager auth.Manager + ensureAccount EnsureAccountFunc + getUserFromUserAuth GetUserFromUserAuthFunc + syncUserJWTGroups SyncUserJWTGroupsFunc } // NewAuthMiddleware instance constructor @@ -32,11 +36,13 @@ func NewAuthMiddleware( authManager auth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, + getUserFromUserAuth GetUserFromUserAuthFunc, ) *AuthMiddleware { return &AuthMiddleware{ - authManager: authManager, - ensureAccount: ensureAccount, - syncUserJWTGroups: syncUserJWTGroups, + authManager: authManager, + ensureAccount: ensureAccount, + syncUserJWTGroups: syncUserJWTGroups, + getUserFromUserAuth: getUserFromUserAuth, } } @@ -123,6 +129,12 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err) } + _, err = m.getUserFromUserAuth(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) + return r, err + } + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 3dc7d51cb..410ff7e15 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -190,6 +190,9 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) error { return nil }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -291,6 +294,9 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) error { return nil }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, ) for _, tc := range tt { From c69df13515b49e570241900b0682a88be8b36621 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 23 Apr 2025 18:44:22 +0200 Subject: [PATCH 05/45] [management] Add account meta (#3724) --- management/server/account.go | 13 ++++++++++ management/server/account/manager.go | 1 + management/server/http/api/openapi.yml | 21 +++++++++++++++ management/server/http/api/types.gen.go | 12 +++++++++ .../handlers/accounts/accounts_handler.go | 26 +++++++++++++++---- .../accounts/accounts_handler_test.go | 6 +++++ management/server/mock_server/account_mock.go | 9 +++++++ management/server/store/sql_store.go | 15 +++++++++++ management/server/store/sql_store_test.go | 16 ++++++++++++ management/server/store/store.go | 1 + management/server/testdata/extended-store.sql | 2 +- management/server/types/account.go | 21 +++++++++++++++ 12 files changed, 137 insertions(+), 6 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index d7f108dfe..fb0a9b65e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1057,6 +1057,19 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return am.Store.GetAccount(ctx, accountID) } +// GetAccountMeta returns the account metadata associated with this account ID. +func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) +} + func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index ea664d10e..b6eb7de05 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -37,6 +37,7 @@ type Manager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index c699e9eef..1717c89ac 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -43,9 +43,30 @@ components: example: ch8i4ug6lnn4g9hqv7l0 settings: $ref: '#/components/schemas/AccountSettings' + domain: + description: Account domain + type: string + example: netbird.io + domain_category: + description: Account domain category + type: string + example: private + created_at: + description: Account creation date (UTC) + type: string + format: date-time + example: "2023-05-05T09:00:35.477782Z" + created_by: + description: Account creator + type: string + example: google-oauth2|277474792786460067937 required: - id - settings + - domain + - domain_category + - created_at + - created_by AccountSettings: type: object properties: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 9bdb3e4ac..3fca40366 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -223,6 +223,18 @@ type AccessiblePeer struct { // Account defines model for Account. type Account struct { + // CreatedAt Account creation date (UTC) + CreatedAt time.Time `json:"created_at"` + + // CreatedBy Account creator + CreatedBy string `json:"created_by"` + + // Domain Account domain + Domain string `json:"domain"` + + // DomainCategory Account domain category + DomainCategory string `json:"domain_category"` + // Id Account ID Id string `json:"id"` Settings AccountSettings `json:"settings"` diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 6c8f8028a..c0851102f 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -47,13 +47,19 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { accountID, userID := userAuth.AccountId, userAuth.UserId + meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(accountID, settings) + resp := toAccountResponse(accountID, settings, meta) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -120,7 +126,13 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) + meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta) util.WriteJSONObject(r.Context(), w, &resp) } @@ -149,7 +161,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -177,7 +189,11 @@ func toAccountResponse(accountID string, settings *types.Settings) *api.Account } return &api.Account{ - Id: accountID, - Settings: apiSettings, + Id: accountID, + Settings: apiSettings, + CreatedAt: meta.CreatedAt, + CreatedBy: meta.CreatedBy, + Domain: meta.Domain, + DomainCategory: meta.DomainCategory, } } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e971a6514..2acca4f49 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -50,6 +50,12 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { accCopy.UpdateSettings(newSettings) return accCopy, nil }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { + return account.Copy(), nil + }, + GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { + return account.GetMeta(), nil + }, }, settingsManager: settingsMockManager, } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 870fe3219..804877a66 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -116,6 +116,7 @@ type MockAccountManager struct { UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) + GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -803,6 +804,14 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented") } +// GetAccountByID mocks GetAccountByID of the AccountManager interface +func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { + if am.GetAccountMetaFunc != nil { + return am.GetAccountMetaFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented") +} + // GetUserByID mocks GetUserByID of the AccountManager interface func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index aacb56ab8..b73c372ae 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -658,6 +658,21 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { return all } +func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) { + var accountMeta types.AccountMeta + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + First(&accountMeta, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + return &accountMeta, nil +} + func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 589e727e9..c16a50108 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3247,3 +3247,19 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { require.NoError(t, err) require.Equal(t, 8003, len(accountGroups)) } + +func TestSqlStore_GetAccountMeta(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + accountMeta, err := store.GetAccountMeta(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.NotNil(t, accountMeta) + require.Equal(t, accountID, accountMeta.AccountID) + require.Equal(t, "edafee4e-63fb-11ec-90d6-0242ac120003", accountMeta.CreatedBy) + require.Equal(t, "test.com", accountMeta.Domain) + require.Equal(t, "private", accountMeta.DomainCategory) + require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index c13a8dfe6..4a26bf5c3 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -50,6 +50,7 @@ type Store interface { GetAccountsCounter(ctx context.Context) (int64, error) GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) + GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 2859e82c8..7900dabf5 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -25,7 +25,7 @@ CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); -INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:01:38.210000+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cfefqs706sqkneg59g2g"]',0,0); INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["abcd"]',0,0); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,NULL,'2024-10-02 16:01:38.210678+02:00','api',0,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index 687709991..ea5f50001 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -40,6 +40,17 @@ const ( type LookupMap map[string]struct{} +// AccountMeta is a struct that contains a stripped down version of the Account object. +// It doesn't carry any peers, groups, policies, or routes, etc. Just some metadata (e.g. ID, created by, created at, etc). +type AccountMeta struct { + // AccountId is the unique identifier of the account + AccountID string `gorm:"column:id"` + CreatedAt time.Time + CreatedBy string + Domain string + DomainCategory string +} + // Account represents a unique account of the system type Account struct { // we have to name column to aid as it collides with Network.Id when work with associations @@ -855,6 +866,16 @@ func (a *Account) Copy() *Account { } } +func (a *Account) GetMeta() *AccountMeta { + return &AccountMeta{ + AccountID: a.Id, + CreatedBy: a.CreatedBy, + CreatedAt: a.CreatedAt, + Domain: a.Domain, + DomainCategory: a.DomainCategory, + } +} + func (a *Account) GetGroupAll() (*Group, error) { for _, g := range a.Groups { if g.Name == "All" { From 8db05838cad9353ff3d3bd22739f3eb16ad72850 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:35:26 +0200 Subject: [PATCH 06/45] [misc] Change github runner for docker test (#3707) --- .github/workflows/golang-test-linux.yml | 42 ++++++++++++++++--------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index e727aa4e5..4e690ff1b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -545,7 +545,7 @@ jobs: test_client_on_docker: name: "Client (Docker) / Unit" needs: [ build-cache ] - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 @@ -559,7 +559,7 @@ jobs: - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV - echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env.GOMODCACHE)" >> $GITHUB_ENV - name: Cache Go modules uses: actions/cache/restore@v4 @@ -577,17 +577,31 @@ jobs: - name: Install modules run: go mod tidy - - name: check git status + - name: Check git status run: git --no-pager diff --exit-code - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager + run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager - - name: Generate SystemOps Test bin - run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops + - name: Generate SystemOps Test bin (static via Alpine) + run: | + docker run --rm -v $PWD:/app -w /app \ + alpine:latest \ + sh -c " + apk add --no-cache go gcc musl-dev libpcap-dev dbus-dev && \ + adduser -D -u $(id -u) builder && \ + su builder -c '\ + cd /app && \ + CGO_ENABLED=1 GOOS=linux GOARCH=amd64 \ + go test -c -o /app/systemops-testing.bin \ + -tags netgo \ + -ldflags=\"-w -extldflags \\\"-static -ldbus-1 -lpcap\\\"\" \ + ./client/internal/routemanager/systemops \ + ' + " - name: Generate nftables Manager Test bin run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... @@ -601,25 +615,25 @@ jobs: - run: chmod +x *testing.bin - name: Run Shared Sock tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /ci/sharedsock-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - name: Run Iface tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... - name: Run RouteManager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /ci/routemanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - name: Run SystemOps tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /ci/systemops-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - name: Run nftables Manager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /ci/nftablesmanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - name: Run Engine tests in docker with file store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - name: Run Engine tests in docker with sqlite store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - name: Run Peer tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /ci/peer-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 From 312bfd9bd789c10c17e9c6d068df7dabefc6618f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:36:53 +0200 Subject: [PATCH 07/45] [management] support custom domains per account (#3726) --- client/cmd/testutil_test.go | 5 ++ management/server/account.go | 21 +++++++- management/server/account/manager.go | 2 +- management/server/activity/codes.go | 4 ++ management/server/group.go | 11 +++- management/server/grpcserver.go | 44 ++++++++++----- management/server/http/api/openapi.yml | 4 ++ management/server/http/api/types.gen.go | 4 +- .../handlers/accounts/accounts_handler.go | 4 ++ .../accounts/accounts_handler_test.go | 4 ++ .../http/handlers/peers/peers_handler.go | 25 +++++++-- .../http/handlers/peers/peers_handler_test.go | 5 +- management/server/mock_server/account_mock.go | 6 +-- management/server/peer.go | 53 ++++++++++++------- management/server/types/settings.go | 4 ++ management/server/user.go | 8 ++- 16 files changed, 158 insertions(+), 46 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 70abe4abe..258a8daff 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -98,6 +98,11 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc settingsMockManager := settings.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl) + settingsMockManager.EXPECT(). + GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&types.Settings{}, nil). + AnyTimes() + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { t.Fatal(err) diff --git a/management/server/account.go b/management/server/account.go index fb0a9b65e..cc5ca309a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -275,6 +275,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } + if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) + } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -325,6 +329,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco account.Network.Serial++ } + if oldSettings.DNSDomain != newSettings.DNSDomain { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil) + updateAccountPeers = true + account.Network.Serial++ + } + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err @@ -1493,8 +1503,15 @@ func isDomainValid(domain string) bool { } // GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain() string { - return am.dnsDomain +func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return am.dnsDomain + } + if settings.DNSDomain == "" { + return am.dnsDomain + } + + return settings.DNSDomain } func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b6eb7de05..aed83349f 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -81,7 +81,7 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain() string + GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 46ae754cf..ed4be82e2 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -169,6 +169,8 @@ const ( ResourceAddedToGroup Activity = 82 ResourceRemovedFromGroup Activity = 83 + + AccountDNSDomainUpdated Activity = 84 ) var activityMap = map[Activity]Code{ @@ -264,6 +266,8 @@ var activityMap = map[Activity]Code{ ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, + + AccountDNSDomainUpdated: {"Account DNS domain updated", "account.dns.domain.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/group.go b/management/server/group.go index 0bd840798..87d649228 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -158,6 +158,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return nil } + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) + return nil + } + dnsDomain := am.GetDNSDomain(settings) + for _, peerID := range addedPeers { peer, ok := peers[peerID] if !ok { @@ -168,7 +175,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac eventsToStore = append(eventsToStore, func() { meta := map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain), } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) @@ -184,7 +191,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac eventsToStore = append(eventsToStore, func() { meta := map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain), } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index a7ed639c3..43d35f643 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -480,20 +480,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.ephemeralManager.OnPeerDisconnected(ctx, peer) } - var relayToken *Token - if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { - relayToken, err = s.secretsManager.GenerateRelayToken() - if err != nil { - log.Errorf("failed generating Relay token: %v", err) - } + loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) + if err != nil { + log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") } - // if peer has reached this point then it has logged in - loginResp := &proto.LoginResponse{ - NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false), - Checks: toProtocolChecks(ctx, postureChecks), - } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -506,6 +498,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } +func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { + var relayToken *Token + var err error + if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { + relayToken, err = s.secretsManager.GenerateRelayToken() + if err != nil { + log.Errorf("failed generating Relay token: %v", err) + } + } + + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator) + if err != nil { + log.WithContext(ctx).Warnf("failed getting settings for peer %s: %s", peer.Key, err) + return nil, status.Errorf(codes.Internal, "failed getting settings") + } + + // if peer has reached this point then it has logged in + loginResp := &proto.LoginResponse{ + NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), false), + Checks: toProtocolChecks(ctx, postureChecks), + } + + return loginResp, nil +} + // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // the token is valid. // @@ -712,7 +730,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 1717c89ac..c0ce06daa 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -112,6 +112,10 @@ components: description: Enables or disables DNS resolution on the routing peers type: boolean example: true + dns_domain: + description: Allows to define a custom dns domain for the account + type: string + example: my-organization.org extra: $ref: '#/components/schemas/AccountExtraSettings' required: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 3fca40366..243f2fdf9 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -259,7 +259,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { - Extra *AccountExtraSettings `json:"extra,omitempty"` + // DnsDomain Allows to define a custom dns domain for the account + DnsDomain *string `json:"dns_domain,omitempty"` + Extra *AccountExtraSettings `json:"extra,omitempty"` // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index c0851102f..7cad26bd6 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -119,6 +119,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.RoutingPeerDnsResolutionEnabled != nil { settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled } + if req.Settings.DnsDomain != nil { + settings.DNSDomain = *req.Settings.DnsDomain + } updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { @@ -178,6 +181,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, + DnsDomain: &settings.DNSDomain, } if settings.Extra != nil { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 2acca4f49..57bbffc7c 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -108,6 +108,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: true, expectedID: accountID, @@ -128,6 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: false, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, @@ -148,6 +150,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{"test"}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, @@ -168,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index fa78836d8..58ea06ea3 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -65,7 +65,13 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain() + settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(ctx, err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) @@ -110,7 +116,13 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain() + + settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(ctx, err, w) + return + } + dnsDomain := h.accountManager.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -192,7 +204,12 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain() + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + dnsDomain := h.accountManager.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) @@ -279,7 +296,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain() + dnsDomain := h.accountManager.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index a03c3c29d..a1fc13dd3 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -152,7 +152,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, }, nil }, - GetDNSDomainFunc: func() string { + GetDNSDomainFunc: func(settings *types.Settings) string { return "netbird.selfhosted" }, GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { @@ -172,6 +172,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { _, ok := statuses[peerID] return ok }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return account.Settings, nil + }, }, } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 804877a66..2b57e6888 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -83,7 +83,7 @@ type MockAccountManager struct { CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error - GetDNSDomainFunc func() string + GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) @@ -620,9 +620,9 @@ func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID, n } // GetDNSDomain mocks GetDNSDomain of the AccountManager interface -func (am *MockAccountManager) GetDNSDomain() string { +func (am *MockAccountManager) GetDNSDomain(settings *types.Settings) string { if am.GetDNSDomainFunc != nil { - return am.GetDNSDomainFunc() + return am.GetDNSDomainFunc(settings) } return "" } diff --git a/management/server/peer.go b/management/server/peer.go index 27825a148..908610fbe 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -206,6 +206,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var sshChanged bool var loginExpirationChanged bool var inactivityExpirationChanged bool + var dnsDomain string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) @@ -223,7 +224,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) + dnsDomain = am.GetDNSDomain(settings) + + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -276,11 +279,11 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.SSHEnabled { event = activity.PeerSSHDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) } if peerLabelChanged { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(dnsDomain)) } if loginExpirationChanged { @@ -288,7 +291,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.LoginExpirationEnabled { event = activity.PeerLoginExpirationDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.checkAndSchedulePeerLoginExpiration(ctx, accountID) @@ -300,7 +303,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.InactivityExpirationEnabled { event = activity.PeerInactivityExpirationDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) @@ -413,7 +416,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin if err != nil { return nil, err } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { @@ -574,8 +577,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } @@ -591,10 +599,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return fmt.Errorf("failed to get account settings: %w", err) - } newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) @@ -1024,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { @@ -1060,7 +1064,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) return nil } @@ -1174,7 +1183,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account semaphore := make(chan struct{}, 10) dnsCache := &DNSConfigCache{} - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + dnsDomain := am.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -1215,7 +1225,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1270,7 +1280,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI } dnsCache := &DNSConfigCache{} - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + dnsDomain := am.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -1299,7 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } @@ -1484,6 +1495,12 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + dnsDomain := am.GetDNSDomain(settings) + for _, peer := range peers { if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { return nil, err @@ -1514,7 +1531,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 7054ede8c..c8de2a98c 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -39,6 +39,9 @@ type Settings struct { // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers RoutingPeerDNSResolutionEnabled bool + // DNSDomain is the custom domain for that account + DNSDomain string + // Extra is a dictionary of Account settings Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` } @@ -58,6 +61,7 @@ func (s *Settings) Copy() *Settings { PeerInactivityExpiration: s.PeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + DNSDomain: s.DNSDomain, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/user.go b/management/server/user.go index 9ec16e72c..b46ed24cf 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -940,6 +940,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a // expireAndUpdatePeers expires all peers of the given user and updates them in the account func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return err + } + dnsDomain := am.GetDNSDomain(settings) + var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -957,7 +963,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou am.StoreEvent( ctx, peer.UserID, peer.ID, accountID, - activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), + activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) } From 4013298e22b8aa0d69391ec2eafd9c0c1d418b71 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Wed, 23 Apr 2025 22:04:38 +0300 Subject: [PATCH 08/45] [client/ui] add connecting state to status handling (#3712) --- client/ui/client_ui.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index b2a6404bb..d0b1bacf6 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -457,7 +457,7 @@ func (s *serviceClient) menuUpClick() error { if status.Status == string(internal.StatusConnected) { log.Warnf("already connected") - return err + return nil } if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { @@ -482,7 +482,7 @@ func (s *serviceClient) menuDownClick() error { return err } - if status.Status != string(internal.StatusConnected) { + if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { log.Warnf("already down") return nil } @@ -520,7 +520,9 @@ func (s *serviceClient) updateStatus() error { } var systrayIconState bool - if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { + + switch { + case status.Status == string(internal.StatusConnected): s.connected = true s.sendNotification = true if s.isUpdateIconActive { @@ -535,7 +537,9 @@ func (s *serviceClient) updateStatus() error { s.mNetworks.Enable() go s.updateExitNodes() systrayIconState = true - } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { + case status.Status == string(internal.StatusConnecting): + s.setConnectingStatus() + case status.Status != string(internal.StatusConnected) && s.mUp.Disabled(): s.setDisconnectedStatus() systrayIconState = false } @@ -594,6 +598,17 @@ func (s *serviceClient) setDisconnectedStatus() { go s.updateExitNodes() } +func (s *serviceClient) setConnectingStatus() { + s.connected = false + systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) + systray.SetTooltip("NetBird (Connecting)") + s.mStatus.SetTitle("Connecting") + s.mUp.Disable() + s.mDown.Enable() + s.mNetworks.Disable() + s.mExitNode.Disable() +} + func (s *serviceClient) onTrayReady() { systray.SetTemplateIcon(iconDisconnectedMacOS, s.icDisconnected) systray.SetTooltip("NetBird") From 400b9fca329cbe0d7c2dc3e37cdef3115126a7e8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 23 Apr 2025 21:29:46 +0200 Subject: [PATCH 09/45] [management] Add firewall rule route ID and missing route domains (#3700) --- management/proto/management.pb.go | 119 ++++++++++-------- management/proto/management.proto | 3 + management/server/route.go | 2 + management/server/route_test.go | 19 +++ management/server/types/account.go | 1 + management/server/types/firewall_rule.go | 1 + .../server/types/route_firewall_rule.go | 4 + 7 files changed, 95 insertions(+), 54 deletions(-) diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index f3f53bfd4..9d7fdc682 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -3057,6 +3057,8 @@ type RouteFirewallRule struct { CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` // PolicyID is the ID of the policy that this rule belongs to PolicyID []byte `protobuf:"bytes,9,opt,name=PolicyID,proto3" json:"PolicyID,omitempty"` + // RouteID is the ID of the route that this rule belongs to + RouteID string `protobuf:"bytes,10,opt,name=RouteID,proto3" json:"RouteID,omitempty"` } func (x *RouteFirewallRule) Reset() { @@ -3154,6 +3156,13 @@ func (x *RouteFirewallRule) GetPolicyID() []byte { return nil } +func (x *RouteFirewallRule) GetRouteID() string { + if x != nil { + return x.RouteID + } + return "" +} + type ForwardingRule struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3702,7 +3711,7 @@ var file_management_proto_rawDesc = []byte{ 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, - 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xed, 0x02, 0x0a, 0x11, 0x52, 0x6f, + 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, @@ -3725,66 +3734,68 @@ var file_management_proto_rawDesc = []byte{ 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, - 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, - 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, - 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, - 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, - 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, - 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, - 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, - 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, - 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, - 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, - 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, - 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, + 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, + 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, + 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, + 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, + 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, + 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, + 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, + 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, + 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, + 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, + 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, - 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, - 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, - 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, - 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, + 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, - 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index 0f1cdb97a..f0dc16ce2 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -509,6 +509,9 @@ message RouteFirewallRule { // PolicyID is the ID of the policy that this rule belongs to bytes PolicyID = 9; + + // RouteID is the ID of the route that this rule belongs to + string RouteID = 10; } message ForwardingRule { diff --git a/management/server/route.go b/management/server/route.go index 8b91e127a..02755a708 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -398,7 +398,9 @@ func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.Ro Protocol: getProtoProtocol(rule.Protocol), PortInfo: getProtoPortInfo(rule), IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), } } diff --git a/management/server/route_test.go b/management/server/route_test.go index dcda3e6d1..833477b55 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1850,6 +1850,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 80, + RouteID: "route1:peerA", }, { SourceRanges: []string{ @@ -1861,6 +1862,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 320, + RouteID: "route1:peerA", }, } additionalFirewallRule := []*types.RouteFirewallRule{ @@ -1872,6 +1874,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: "192.168.10.0/16", Protocol: "tcp", Port: 80, + RouteID: "route4:peerA", }, { SourceRanges: []string{ @@ -1880,6 +1883,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Action: "accept", Destination: "192.168.10.0/16", Protocol: "all", + RouteID: "route4:peerA", }, } @@ -1888,6 +1892,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) + for _, rule := range expectedRoutesFirewallRules { + rule.RouteID = "route1:peerD" + } assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 @@ -1901,6 +1908,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Destination: existingNetwork.String(), Protocol: "tcp", PortRange: types.RulePortRange{Start: 80, End: 350}, + RouteID: "route2", }, { SourceRanges: []string{"0.0.0.0/0"}, @@ -1909,6 +1917,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Protocol: "all", Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "route3", }, { SourceRanges: []string{"::/0"}, @@ -1917,6 +1926,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Protocol: "all", Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "route3", }, } assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) @@ -2676,6 +2686,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 80, + RouteID: "resource2:peerA", }, { SourceRanges: []string{ @@ -2687,6 +2698,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "192.168.0.0/16", Protocol: "all", Port: 320, + RouteID: "resource2:peerA", }, } @@ -2701,6 +2713,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Port: 80, Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "resource4:peerA", }, { SourceRanges: []string{ @@ -2711,6 +2724,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Protocol: "all", Domains: domain.List{"example.com"}, IsDynamic: true, + RouteID: "resource4:peerA", }, } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) @@ -2719,6 +2733,9 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) assert.Len(t, firewallRules, 2) + for _, rule := range expectedFirewallRules { + rule.RouteID = "resource2:peerD" + } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) assert.Len(t, sourcePeers, 3) @@ -2736,6 +2753,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "10.10.10.0/24", Protocol: "tcp", PortRange: types.RulePortRange{Start: 80, End: 350}, + RouteID: "resource1:peerE", }, } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) @@ -2758,6 +2776,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Destination: "10.12.12.1/32", Protocol: "tcp", Port: 8080, + RouteID: "resource5:peerL", }, } assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) diff --git a/management/server/types/account.go b/management/server/types/account.go index ea5f50001..e9fa37085 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1240,6 +1240,7 @@ func getDefaultPermit(route *route.Route) []*RouteFirewallRule { Protocol: string(PolicyRuleProtocolALL), Domains: route.Domains, IsDynamic: route.IsDynamic(), + RouteID: route.ID, } rules = append(rules, &rule) diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index d98a56871..ef54abea2 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -62,6 +62,7 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule baseRule := RouteFirewallRule{ PolicyID: rule.PolicyID, + RouteID: route.ID, SourceRanges: sourceRanges, Action: string(rule.Action), Destination: route.Network.String(), diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 5b752bc36..c09c64a3d 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -2,6 +2,7 @@ package types import ( "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) // RouteFirewallRule a firewall rule applicable for a routed network. @@ -9,6 +10,9 @@ type RouteFirewallRule struct { // PolicyID is the ID of the policy this rule is derived from PolicyID string + // RouteID is the ID of the route this rule belongs to. + RouteID route.ID + // SourceRanges IP ranges of the routing peers. SourceRanges []string From 714beb6e3b9559ea7fadcc207a5317674745d140 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:36:05 +0200 Subject: [PATCH 10/45] [client] Fix exit node deselection (#3722) --- client/ui/network.go | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/client/ui/network.go b/client/ui/network.go index b21554f09..ddd8d5000 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -456,19 +456,27 @@ func (s *serviceClient) toggleExitNode(nodeID string, item *systray.MenuItem) er } } - if item.Checked() && len(ids) == 0 { - // exit node is the only selected node, deselect it + // exit node is the only selected node, deselect it + deselectAll := item.Checked() && len(ids) == 0 + if deselectAll { ids = append(ids, nodeID) - exitNode = nil + for _, node := range exitNodes { + if node.ID == nodeID { + // set desired state for recreation + node.Selected = false + } + } } // deselect all other selected exit nodes - if err := s.deselectOtherExitNodes(conn, ids, item); err != nil { + if err := s.deselectOtherExitNodes(conn, ids); err != nil { return err } - if err := s.selectNewExitNode(conn, exitNode, nodeID, item); err != nil { - return err + if !deselectAll { + if err := s.selectNewExitNode(conn, exitNode, nodeID, item); err != nil { + return err + } } // linux/bsd doesn't handle Check/Uncheck well, so we recreate the menu @@ -479,7 +487,7 @@ func (s *serviceClient) toggleExitNode(nodeID string, item *systray.MenuItem) er return nil } -func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, ids []string, currentItem *systray.MenuItem) error { +func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, ids []string) error { // deselect all other selected exit nodes if len(ids) > 0 { deselectReq := &proto.SelectNetworksRequest{ @@ -494,9 +502,6 @@ func (s *serviceClient) deselectOtherExitNodes(conn proto.DaemonServiceClient, i // uncheck all other exit node menu items for _, i := range s.mExitNodeItems { - if i.MenuItem == currentItem { - continue - } i.Uncheck() log.Infof("Unchecked exit node %v", i) } @@ -518,6 +523,7 @@ func (s *serviceClient) selectNewExitNode(conn proto.DaemonServiceClient, exitNo } item.Check() + log.Infof("Checked exit node '%s'", nodeID) return nil } From 85f92f8321b16a7cd444fe09ff6379414f9e7a41 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:57:46 +0200 Subject: [PATCH 11/45] [client] Add more userspace filter ACL test cases (#3730) --- .../uspfilter/uspfilter_filter_test.go | 439 +++++++++++++++++- 1 file changed, 419 insertions(+), 20 deletions(-) diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index ba97c2643..9c0a54e3f 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -188,6 +188,281 @@ func TestPeerACLFiltering(t *testing.T) { ruleAction: fw.ActionAccept, shouldBeBlocked: true, }, + { + name: "Allow TCP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow UDP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "TCP packet doesn't match UDP filter with same port", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "UDP packet doesn't match TCP filter with same port", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "ICMP packet doesn't match TCP filter", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "ICMP packet doesn't match UDP filter", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "Allow TCP traffic within port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Block TCP traffic outside port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "Edge Case - Port at Range Boundary", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8100, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "UDP Port Range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 5060, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow multiple destination ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow multiple source ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + // New drop test cases + { + name: "Drop TCP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop UDP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop ICMP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop all traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop traffic from multiple source ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop multiple destination ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Drop TCP traffic within port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Accept TCP traffic outside drop port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: false, + }, + { + name: "Drop TCP traffic with source port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 80, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Mixed rule - drop specific port but allow other ports", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, } t.Run("Implicit DROP (no rules)", func(t *testing.T) { @@ -198,6 +473,28 @@ func TestPeerACLFiltering(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + + if tc.ruleAction == fw.ActionDrop { + // add general accept rule to test drop rule + // TODO: this only works because 0.0.0.0 is tested last, we need to implement order + rules, err := manager.AddPeerFiltering( + nil, + net.ParseIP("0.0.0.0"), + fw.ProtocolALL, + nil, + nil, + fw.ActionAccept, + "", + ) + require.NoError(t, err) + require.NotEmpty(t, rules) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + } + rules, err := manager.AddPeerFiltering( nil, net.ParseIP(tc.ruleIP), @@ -543,26 +840,6 @@ func TestRouteACLFiltering(t *testing.T) { }, shouldPass: true, }, - { - name: "Multiple source networks with mismatched protocol", - srcIP: "172.16.0.1", - dstIP: "192.168.1.100", - // Should not match TCP rule - proto: fw.ProtocolUDP, - srcPort: 12345, - dstPort: 80, - rule: rule{ - sources: []netip.Prefix{ - netip.MustParsePrefix("100.10.0.0/16"), - netip.MustParsePrefix("172.16.0.0/16"), - }, - dest: netip.MustParsePrefix("192.168.1.0/24"), - proto: fw.ProtocolTCP, - dstPort: &fw.Port{Values: []uint16{80}}, - action: fw.ActionAccept, - }, - shouldPass: false, - }, { name: "Allow multiple destination ports", srcIP: "100.10.0.1", @@ -798,10 +1075,132 @@ func TestRouteACLFiltering(t *testing.T) { }, shouldPass: false, }, + { + name: "Accept TCP traffic outside drop port range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + action: fw.ActionDrop, + }, + shouldPass: true, + }, + { + name: "Allow TCP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow UDP traffic without port specification", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "TCP packet doesn't match UDP filter with same port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "UDP packet doesn't match TCP filter with same port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "ICMP packet doesn't match TCP filter", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "ICMP packet doesn't match UDP filter", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + action: fw.ActionAccept, + }, + shouldPass: false, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + if tc.rule.action == fw.ActionDrop { + // add general accept rule to test drop rule + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + netip.MustParsePrefix("0.0.0.0/0"), + fw.ProtocolALL, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + t.Cleanup(func() { + require.NoError(t, manager.DeleteRouteRule(rule)) + }) + } + rule, err := manager.AddRouteFiltering( nil, tc.rule.sources, From 4a9049566a5176304d802ae8331573ccad51c312 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:37:28 +0200 Subject: [PATCH 12/45] [client] Set up firewall rules for dns routes dynamically based on dns response (#3702) --- client/firewall/iptables/manager_linux.go | 17 +- client/firewall/iptables/router_linux.go | 158 ++++++--- client/firewall/iptables/router_linux_test.go | 85 ++++- client/firewall/manager/firewall.go | 66 ++-- client/firewall/manager/firewall_test.go | 16 +- client/firewall/manager/routerpair.go | 6 +- client/firewall/manager/set.go | 74 +++++ client/firewall/nftables/manager_linux.go | 19 +- .../firewall/nftables/manager_linux_test.go | 6 +- client/firewall/nftables/router_linux.go | 306 ++++++++++++------ client/firewall/nftables/router_linux_test.go | 10 +- client/firewall/test/cases_linux.go | 12 +- client/firewall/uspfilter/allow_netbird.go | 2 +- .../uspfilter/allow_netbird_windows.go | 6 +- client/firewall/uspfilter/rule.go | 17 +- client/firewall/uspfilter/tracer_test.go | 4 +- client/firewall/uspfilter/uspfilter.go | 163 +++++++--- .../uspfilter/uspfilter_filter_test.go | 158 ++++++--- client/firewall/uspfilter/uspfilter_test.go | 201 ++++++++++++ client/internal/acl/id/id.go | 2 +- client/internal/acl/manager.go | 54 +++- client/internal/acl/manager_test.go | 10 +- client/internal/debug/debug_linux.go | 32 ++ client/internal/dnsfwd/forwarder.go | 146 ++++++--- client/internal/dnsfwd/forwarder_test.go | 50 +-- client/internal/dnsfwd/manager.go | 35 +- client/internal/engine.go | 58 ++-- client/internal/peer/route.go | 20 +- client/internal/peer/status.go | 14 +- .../routemanager/dnsinterceptor/handler.go | 9 +- client/internal/routemanager/manager.go | 4 +- .../internal/routemanager/server_android.go | 2 +- .../routemanager/server_nonandroid.go | 88 ++--- .../routemanager/systemops/systemops_linux.go | 6 +- client/server/network.go | 2 +- client/status/status.go | 5 +- dns/dns.go | 3 +- go.mod | 21 +- go.sum | 43 ++- management/domain/domain.go | 12 +- management/domain/list.go | 5 +- management/domain/validate.go | 2 - management/server/types/account.go | 2 +- route/hauniqueid.go | 3 +- route/route.go | 36 ++- 45 files changed, 1399 insertions(+), 591 deletions(-) create mode 100644 client/firewall/manager/set.go diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 652ab1b3e..b229688fc 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - if !destination.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -243,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.router.DeleteDNATRule(rule) } +// UpdateSet updates the set with the given prefixes +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.UpdateSet(set, prefixes) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 869b0b359..b59c88580 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -57,18 +57,18 @@ type ruleInfo struct { } type routeFilteringRuleParams struct { - Sources []netip.Prefix - Destination netip.Prefix + Source firewall.Network + Destination firewall.Network Proto firewall.Protocol SPort *firewall.Port DPort *firewall.Port Direction firewall.RuleDirection Action firewall.Action - SetName string } type routeRules map[string][]string +// the ipset library currently does not support comments, so we use the name only (string) type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type router struct { @@ -129,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error { func (r *router) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, @@ -140,27 +140,28 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } - var setName string + var source firewall.Network if len(sources) > 1 { - setName = firewall.GenerateSetName(sources) - if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { - return nil, fmt.Errorf("create or get ipset: %w", err) - } + source.Set = firewall.NewPrefixSet(sources) + } else if len(sources) > 0 { + source.Prefix = sources[0] } params := routeFilteringRuleParams{ - Sources: sources, + Source: source, Destination: destination, Proto: proto, SPort: sPort, DPort: dPort, Action: action, - SetName: setName, } - rule := genRouteFilteringRuleSpec(params) + rule, err := r.genRouteRuleSpec(params, sources) + if err != nil { + return nil, fmt.Errorf("generate route rule spec: %w", err) + } + // Insert DROP rules at the beginning, append ACCEPT rules at the end - var err error if action == firewall.ActionDrop { // after the established rule err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) @@ -183,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() if rule, exists := r.rules[ruleKey]; exists { - setName := r.findSetNameInRule(rule) - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { return fmt.Errorf("delete route rule: %v", err) } delete(r.rules, ruleKey) - if setName != "" { - if _, err := r.ipsetCounter.Decrement(setName); err != nil { - return fmt.Errorf("failed to remove ipset: %w", err) - } + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) } } else { log.Debugf("route rule %s not found", ruleKey) @@ -204,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } -func (r *router) findSetNameInRule(rule []string) string { - for i, arg := range rule { - if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { - return rule[i+3] +func (r *router) decrementSetCounter(rule []string) error { + sets := r.findSets(rule) + var merr *multierror.Error + for _, setName := range sets { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err)) } } - return "" + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) findSets(rule []string) []string { + var sets []string + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + sets = append(sets, rule[i+3]) + } + } + return sets } func (r *router) createIpSet(setName string, sources []netip.Prefix) error { @@ -231,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error { if err := ipset.Destroy(setName); err != nil { return fmt.Errorf("destroy set %s: %w", setName, err) } + + log.Debugf("Deleted unused ipset %s", setName) return nil } @@ -270,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { log.Errorf("%v", err) } - if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove nat rule: %w", err) - } + if pair.Masquerade { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse nat rule: %w", err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -313,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } delete(r.rules, ruleKey) - } else { - log.Debugf("legacy forwarding rule %s not found", ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) + } } return nil @@ -599,12 +615,24 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { rule = append(rule, "-m", "conntrack", "--ctstate", "NEW", - "-s", pair.Source.String(), - "-d", pair.Destination.String(), + ) + sourceExp, err := r.applyNetwork("-s", pair.Source, nil) + if err != nil { + return fmt.Errorf("apply network -s: %w", err) + } + destExp, err := r.applyNetwork("-d", pair.Destination, nil) + if err != nil { + return fmt.Errorf("apply network -d: %w", err) + } + + rule = append(rule, sourceExp...) + rule = append(rule, destExp...) + rule = append(rule, "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), ) if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { + // TODO: rollback ipset counter return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) } @@ -622,6 +650,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) } delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement ipset counter: %w", err) + } } else { log.Debugf("marking rule %s not found", ruleKey) } @@ -787,17 +819,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { return nberrors.FormatErrorOrNil(merr) } -func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { +func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) { var rule []string - if params.SetName != "" { - rule = append(rule, "-m", "set", matchSet, params.SetName, "src") - } else if len(params.Sources) > 0 { - source := params.Sources[0] - rule = append(rule, "-s", source.String()) + sourceExp, err := r.applyNetwork("-s", params.Source, sources) + if err != nil { + return nil, fmt.Errorf("apply network -s: %w", err) + + } + destExp, err := r.applyNetwork("-d", params.Destination, nil) + if err != nil { + return nil, fmt.Errorf("apply network -d: %w", err) } - rule = append(rule, "-d", params.Destination.String()) + rule = append(rule, sourceExp...) + rule = append(rule, destExp...) if params.Proto != firewall.ProtocolALL { rule = append(rule, "-p", strings.ToLower(string(params.Proto))) @@ -807,7 +843,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { rule = append(rule, "-j", actionToStr(params.Action)) - return rule + return rule, nil +} + +func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) { + direction := "src" + if flag == "-d" { + direction = "dst" + } + + if network.IsSet() { + if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } + + return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + } + if network.IsPrefix() { + return []string{flag, network.Prefix.String()}, nil + } + + // nolint:nilnil + return nil, nil +} + +func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + var merr *multierror.Error + for _, prefix := range prefixes { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err)) + } + } + if merr == nil { + log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + } + + return nberrors.FormatErrorOrNil(merr) } func applyPort(flag string, port *firewall.Port) []string { diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index dad77dee7..e9eeff863 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -60,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { pair := firewall.RouterPair{ ID: "abc", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.100.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")}, Masquerade: true, } @@ -332,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") // Check if the rule is in the internal map @@ -347,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) { assert.NoError(t, err, "Failed to check rule existence") assert.True(t, exists, "Rule not found in iptables") + var source firewall.Network + if len(tt.sources) > 1 { + source.Set = firewall.NewPrefixSet(tt.sources) + } else if len(tt.sources) > 0 { + source.Prefix = tt.sources[0] + } // Verify rule content params := routeFilteringRuleParams{ - Sources: tt.sources, - Destination: tt.destination, + Source: source, + Destination: firewall.Network{Prefix: tt.destination}, Proto: tt.proto, SPort: tt.sPort, DPort: tt.dPort, Action: tt.action, - SetName: "", } - expectedRule := genRouteFilteringRuleSpec(params) + expectedRule, err := r.genRouteRuleSpec(params, nil) + require.NoError(t, err, "Failed to generate expected rule spec") if tt.expectSet { - setName := firewall.GenerateSetName(tt.sources) - params.SetName = setName - expectedRule = genRouteFilteringRuleSpec(params) + setName := firewall.NewPrefixSet(tt.sources).HashedName() + expectedRule, err = r.genRouteRuleSpec(params, nil) + require.NoError(t, err, "Failed to generate expected rule spec with set") // Check if the set was created _, exists := r.ipsetCounter.Get(setName) @@ -378,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) { }) } } + +func TestFindSetNameInRule(t *testing.T) { + r := &router{} + + testCases := []struct { + name string + rule []string + expected []string + }{ + { + name: "Basic rule with two sets", + rule: []string{ + "-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src", + "-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT", + }, + expected: []string{"nb-2e5a2a05", "nb-349ae051"}, + }, + { + name: "No sets", + rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"}, + expected: []string{}, + }, + { + name: "Multiple sets with different positions", + rule: []string{ + "-m", "set", "--match-set", "set1", "src", "-p", "tcp", + "-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT", + }, + expected: []string{"set1", "set-abc123"}, + }, + { + name: "Boundary case - sequence appears at end", + rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"}, + expected: []string{"final-set"}, + }, + { + name: "Incomplete pattern - missing set name", + rule: []string{"-p", "tcp", "-m", "set", "--match-set"}, + expected: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := r.findSets(tc.rule) + + if len(result) != len(tc.expected) { + t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result) + return + } + + for i, set := range result { + if set != tc.expected[i] { + t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set) + } + } + }) + } +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 1d71051ef..084d19423 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,13 +1,10 @@ package manager import ( - "crypto/sha256" - "encoding/hex" "fmt" "net" "net/netip" "sort" - "strings" log "github.com/sirupsen/logrus" @@ -43,6 +40,18 @@ const ( // Action is the action to be taken on a rule type Action int +// String returns the string representation of the action +func (a Action) String() string { + switch a { + case ActionAccept: + return "accept" + case ActionDrop: + return "drop" + default: + return "unknown" + } +} + const ( // ActionAccept is the action to accept a packet ActionAccept Action = iota @@ -50,6 +59,33 @@ const ( ActionDrop ) +// Network is a rule destination, either a set or a prefix +type Network struct { + Set Set + Prefix netip.Prefix +} + +// String returns the string representation of the destination +func (d Network) String() string { + if d.Prefix.IsValid() { + return d.Prefix.String() + } + if d.IsSet() { + return d.Set.HashedName() + } + return "" +} + +// IsSet returns true if the destination is a set +func (d Network) IsSet() bool { + return d.Set != Set{} +} + +// IsPrefix returns true if the destination is a valid prefix +func (d Network) IsPrefix() bool { + return d.Prefix.IsValid() +} + // Manager is the high level abstraction of a firewall manager // // It declares methods which handle actions required by the @@ -83,10 +119,9 @@ type Manager interface { AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination Network, proto Protocol, - sPort *Port, - dPort *Port, + sPort, dPort *Port, action Action, ) (Rule, error) @@ -119,6 +154,9 @@ type Manager interface { // DeleteDNATRule deletes a DNAT rule DeleteDNATRule(Rule) error + + // UpdateSet updates the set with the given prefixes + UpdateSet(hash Set, prefixes []netip.Prefix) error } func GenKey(format string, pair RouterPair) string { @@ -153,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { return nil } -// GenerateSetName generates a unique name for an ipset based on the given sources. -func GenerateSetName(sources []netip.Prefix) string { - // sort for consistent naming - SortPrefixes(sources) - - var sourcesStr strings.Builder - for _, src := range sources { - sourcesStr.WriteString(src.String()) - } - - hash := sha256.Sum256([]byte(sourcesStr.String())) - shortHash := hex.EncodeToString(hash[:])[:8] - - return fmt.Sprintf("nb-%s", shortHash) -} - // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { if len(prefixes) == 0 { diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go index 3f47d6679..180346906 100644 --- a/client/firewall/manager/firewall_test.go +++ b/client/firewall/manager/firewall_test.go @@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("192.168.1.0/24"), } - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) + result1 := manager.NewPrefixSet(prefixes1) + result2 := manager.NewPrefixSet(prefixes2) if result1 != result2 { t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) @@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("10.0.0.0/8"), } - result := manager.GenerateSetName(prefixes) + result := manager.NewPrefixSet(prefixes) - matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName()) if err != nil { t.Fatalf("Error matching regex: %v", err) } @@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) { }) t.Run("Empty input produces consistent result", func(t *testing.T) { - result1 := manager.GenerateSetName([]netip.Prefix{}) - result2 := manager.GenerateSetName([]netip.Prefix{}) + result1 := manager.NewPrefixSet([]netip.Prefix{}) + result2 := manager.NewPrefixSet([]netip.Prefix{}) if result1 != result2 { t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) @@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) { netip.MustParsePrefix("192.168.1.0/24"), } - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) + result1 := manager.NewPrefixSet(prefixes1) + result2 := manager.NewPrefixSet(prefixes2) if result1 != result2 { t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index 8c94b7dd4..079c051d9 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,15 +1,13 @@ package manager import ( - "net/netip" - "github.com/netbirdio/netbird/route" ) type RouterPair struct { ID route.ID - Source netip.Prefix - Destination netip.Prefix + Source Network + Destination Network Masquerade bool Inverse bool } diff --git a/client/firewall/manager/set.go b/client/firewall/manager/set.go new file mode 100644 index 000000000..4c88f6eac --- /dev/null +++ b/client/firewall/manager/set.go @@ -0,0 +1,74 @@ +package manager + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/netip" + "slices" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" +) + +type Set struct { + hash [4]byte + comment string +} + +// String returns the string representation of the set: hashed name and comment +func (h Set) String() string { + if h.comment == "" { + return h.HashedName() + } + return h.HashedName() + ": " + h.comment +} + +// HashedName returns the string representation of the hash +func (h Set) HashedName() string { + return fmt.Sprintf( + "nb-%s", + hex.EncodeToString(h.hash[:]), + ) +} + +// Comment returns the comment of the set +func (h Set) Comment() string { + return h.comment +} + +// NewPrefixSet generates a unique name for an ipset based on the given prefixes. +func NewPrefixSet(prefixes []netip.Prefix) Set { + // sort for consistent naming + SortPrefixes(prefixes) + + hash := sha256.New() + for _, src := range prefixes { + bytes, err := src.MarshalBinary() + if err != nil { + log.Warnf("failed to marshal prefix %s: %v", src, err) + } + hash.Write(bytes) + } + var set Set + copy(set.hash[:], hash.Sum(nil)[:4]) + + return set +} + +// NewDomainSet generates a unique name for an ipset based on the given domains. +func NewDomainSet(domains domain.List) Set { + slices.Sort(domains) + + hash := sha256.New() + for _, d := range domains { + hash.Write([]byte(d.PunycodeString())) + } + set := Set{ + comment: domains.SafeString(), + } + copy(set.hash[:], hash.Sum(nil)[:4]) + + return set +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a5809471c..e6b3a031b 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - if !destination.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -242,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { return firewall.SetLegacyManagement(m.router, isLegacy) } -// Reset firewall to the default state +// Close closes the firewall manager func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -359,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.router.DeleteDNATRule(rule) } +// UpdateSet updates the set with the given prefixes +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.UpdateSet(set, prefixes) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 373743a08..602a6b8dc 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { _, err = manager.AddRouteFiltering( nil, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, - netip.MustParsePrefix("10.1.0.0/24"), + fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{443}}, @@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { require.NoError(t, err, "failed to add route filtering rule") pair := fw.RouterPair{ - Source: netip.MustParsePrefix("192.168.1.0/24"), - Destination: netip.MustParsePrefix("10.0.0.0/24"), + Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")}, Masquerade: true, } err = manager.AddNatRule(pair) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aff86dd90..c2ba2a072 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/coreos/go-iptables/iptables" - "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -44,9 +43,14 @@ const ( const refreshRulesMapError = "refresh rules map: %w" var ( - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") + errFilterTableNotFound = fmt.Errorf("'filter' table not found") ) +type setInput struct { + set firewall.Set + prefixes []netip.Prefix +} + type router struct { conn *nftables.Conn workTable *nftables.Table @@ -54,7 +58,7 @@ type router struct { chains map[string]*nftables.Chain // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules rules map[string]*nftables.Rule - ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState @@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + return nil, fmt.Errorf("unable to list tables: %v", err) } for _, table := range tables { @@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error { func (r *router) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, @@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering( chain := r.chains[chainNameRoutingFw] var exprs []expr.Any + var source firewall.Network switch { case len(sources) == 1 && sources[0].Bits() == 0: // If it's 0.0.0.0/0, we don't need to add any source matching case len(sources) == 1: // If there's only one source, we can use it directly - exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) + source.Prefix = sources[0] default: - // If there are multiple sources, create or get an ipset - var err error - exprs, err = r.getIpSetExprs(sources, exprs) - if err != nil { - return nil, fmt.Errorf("get ipset expressions: %w", err) - } + // If there are multiple sources, use a set + source.Set = firewall.NewPrefixSet(sources) } - // Handle destination - exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + sourceExp, err := r.applyNetwork(source, sources, true) + if err != nil { + return nil, fmt.Errorf("apply source: %w", err) + } + exprs = append(exprs, sourceExp...) + + destExp, err := r.applyNetwork(destination, nil, false) + if err != nil { + return nil, fmt.Errorf("apply destination: %w", err) + } + exprs = append(exprs, destExp...) // Handle protocol if proto != firewall.ProtocolALL { @@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering( rule = r.conn.AddRule(rule) } - log.Tracef("Adding route rule %s", spew.Sdump(rule)) if err := r.conn.Flush(); err != nil { return nil, fmt.Errorf(flushError, err) } r.rules[string(ruleKey)] = rule - log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) return ruleKey, nil } -func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { - setName := firewall.GenerateSetName(sources) - ref, err := r.ipsetCounter.Increment(setName, sources) +func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) { + ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{ + set: set, + prefixes: prefixes, + }) if err != nil { - return nil, fmt.Errorf("create or get ipset for sources: %w", err) + return nil, fmt.Errorf("create or get ipset: %w", err) } - exprs = append(exprs, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Lookup{ - SourceRegister: 1, - SetName: ref.Out.Name, - SetID: ref.Out.ID, - }, - ) - return exprs, nil + return getIpSetExprs(ref, isSource) } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return fmt.Errorf("route rule %s has no handle", ruleKey) } - setName := r.findSetNameInRule(nftRule) - if err := r.deleteNftRule(nftRule, ruleKey); err != nil { return fmt.Errorf("delete: %w", err) } - if setName != "" { - if _, err := r.ipsetCounter.Decrement(setName); err != nil { - return fmt.Errorf("decrement ipset reference: %w", err) - } - } - if err := r.conn.Flush(); err != nil { return fmt.Errorf(flushError, err) } + if err := r.decrementSetCounter(nftRule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } + return nil } -func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { +func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) { // overlapping prefixes will result in an error, so we need to merge them - sources = firewall.MergeIPRanges(sources) + prefixes := firewall.MergeIPRanges(input.prefixes) - set := &nftables.Set{ - Name: setName, - Table: r.workTable, + nfset := &nftables.Set{ + Name: setName, + Comment: input.set.Comment(), + Table: r.workTable, // required for prefixes Interval: true, KeyType: nftables.TypeIPAddr, } + elements := convertPrefixesToSet(prefixes) + if err := r.conn.AddSet(nfset, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return nfset, nil +} + +func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement - for _, prefix := range sources { + for _, prefix := range prefixes { // TODO: Implement IPv6 support if prefix.Addr().Is6() { - log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) continue } @@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables. nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) } - - if err := r.conn.AddSet(set, elements); err != nil { - return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) - } - - if err := r.conn.Flush(); err != nil { - return nil, fmt.Errorf("flush error: %w", err) - } - - log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) - - return set, nil + return elements } // calculateLastIP determines the last IP in a given prefix. @@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte { return b } -func (r *router) deleteIpSet(setName string, set *nftables.Set) error { - r.conn.DelSet(set) +func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error { + r.conn.DelSet(nfset) if err := r.conn.Flush(); err != nil { return fmt.Errorf(flushError, err) } @@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error { return nil } -func (r *router) findSetNameInRule(rule *nftables.Rule) string { - for _, e := range rule.Exprs { - if lookup, ok := e.(*expr.Lookup); ok { - return lookup.SetName +func (r *router) decrementSetCounter(rule *nftables.Rule) error { + sets := r.findSets(rule) + + var merr *multierror.Error + for _, setName := range sets { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err)) } } - return "" + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) findSets(rule *nftables.Rule) []string { + var sets []string + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + sets = append(sets, lookup.SetName) + } + } + return sets } func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { @@ -586,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + // TODO: rollback ipset counter + return fmt.Errorf("insert rules for %s: %v", pair.Destination, err) } return nil @@ -594,19 +608,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { // addNatRule inserts a nftables rule to the conn client flush queue func (r *router) addNatRule(pair firewall.RouterPair) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) + sourceExp, err := r.applyNetwork(pair.Source, nil, true) + if err != nil { + return fmt.Errorf("apply source: %w", err) + } + + destExp, err := r.applyNetwork(pair.Destination, nil, false) + if err != nil { + return fmt.Errorf("apply destination: %w", err) + } op := expr.CmpOpEq if pair.Inverse { op = expr.CmpOpNeq } - // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. - // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. - exprs := getCtNewExprs() - exprs = append(exprs, - // interface matching + exprs := []expr.Any{ &expr.Meta{ Key: expr.MetaKeyIIFNAME, Register: 1, @@ -616,7 +633,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: ifname(r.wgIface.Name()), }, - ) + } + // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. + // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. + exprs = append(exprs, getCtNewExprs()...) exprs = append(exprs, sourceExp...) exprs = append(exprs, destExp...) @@ -729,8 +749,15 @@ func (r *router) addPostroutingRules() error { // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) + sourceExp, err := r.applyNetwork(pair.Source, nil, true) + if err != nil { + return fmt.Errorf("apply source: %w", err) + } + + destExp, err := r.applyNetwork(pair.Destination, nil, false) + if err != nil { + return fmt.Errorf("apply destination: %w", err) + } exprs := []expr.Any{ &expr.Counter{}, @@ -739,7 +766,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { }, } - expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) @@ -752,7 +780,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainNameRoutingFw], - Exprs: expression, + Exprs: exprs, UserData: []byte(ruleKey), }) return nil @@ -767,11 +795,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) - } else { - log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } } return nil @@ -982,12 +1012,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf(refreshRulesMapError, err) } - if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove prerouting rule: %w", err) - } + if pair.Masquerade { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove prerouting rule: %w", err) + } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse prerouting rule: %w", err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse prerouting rule: %w", err) + } } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -995,10 +1027,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + // TODO: rollback set counter + return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } @@ -1006,16 +1038,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - err := r.conn.DelRule(rule) - if err != nil { + if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) + + if err := r.decrementSetCounter(rule); err != nil { + return fmt.Errorf("decrement set counter: %w", err) + } } else { - log.Debugf("nftables: prerouting rule %s not found", ruleKey) + log.Debugf("prerouting rule %s not found", ruleKey) } return nil @@ -1027,7 +1062,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) + return fmt.Errorf(" unable to list rules: %v", err) } for _, rule := range rules { if len(rule.UserData) > 0 { @@ -1301,13 +1336,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { return nberrors.FormatErrorOrNil(merr) } -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { - var offset uint32 - if source { - offset = 12 // src offset - } else { - offset = 16 // dst offset +func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName()) + if err != nil { + return fmt.Errorf("get set %s: %w", set.HashedName(), err) + } + + elements := convertPrefixesToSet(prefixes) + if err := r.conn.SetAddElements(nfset, elements); err != nil { + return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + + return nil +} + +// applyNetwork generates nftables expressions for networks (CIDR) or sets +func (r *router) applyNetwork( + network firewall.Network, + setPrefixes []netip.Prefix, + isSource bool, +) ([]expr.Any, error) { + if network.IsSet() { + exprs, err := r.getIpSet(network.Set, setPrefixes, isSource) + if err != nil { + return nil, fmt.Errorf("source: %w", err) + } + return exprs, nil + } + + if network.IsPrefix() { + return applyPrefix(network.Prefix, isSource), nil + } + + return nil, nil +} + +// applyPrefix generates nftables expressions for a CIDR prefix +func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset + offset := uint32(16) + if isSource { + // src offset + offset = 12 } ones := prefix.Bits() @@ -1415,3 +1491,27 @@ func getCtNewExprs() []expr.Any { }, } } + +func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + + // dst offset + offset := uint32(16) + if isSource { + // src offset + offset = 12 + } + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + }, nil +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 28baef4dd..4fdbf3505 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") t.Cleanup(func() { @@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setName := firewall.GenerateSetName(tt.sources) - set, err := r.createIpSet(setName, tt.sources) + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) if err != nil { t.Logf("Failed to create IP set: %v", err) printNftSets() diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 267e93efd..59a370a97 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -15,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: false, }, }, @@ -24,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: true, }, }, @@ -40,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: netip.MustParsePrefix("100.100.100.1/32"), - Destination: netip.MustParsePrefix("100.100.200.0/24"), + Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")}, + Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")}, Masquerade: true, }, }, diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 5fe698aa9..ce04c82c7 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -12,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -// Reset firewall to the default state +// Close cleans up the firewall manager by removing all rules and closing trackers func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index f63792fec..f261c472f 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -22,7 +21,7 @@ const ( firewallRuleName = "Netbird" ) -// Reset firewall to the default state +// Close cleans up the firewall manager by removing all rules and closing trackers func (m *Manager) Close(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } if fwder := m.forwarder.Load(); fwder != nil { diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index a23d2011b..b765c72e9 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -29,14 +29,15 @@ func (r *PeerRule) ID() string { } type RouteRule struct { - id string - mgmtId []byte - sources []netip.Prefix - destination netip.Prefix - proto firewall.Protocol - srcPort *firewall.Port - dstPort *firewall.Port - action firewall.Action + id string + mgmtId []byte + sources []netip.Prefix + dstSet firewall.Set + destinations []netip.Prefix + proto firewall.Protocol + srcPort *firewall.Port + dstPort *firewall.Port + action firewall.Action } // ID returns the rule id diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 48b0ec44d..53ee6c886 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -199,7 +199,7 @@ func TestTracePacket(t *testing.T) { src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) - _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { @@ -223,7 +223,7 @@ func TestTracePacket(t *testing.T) { src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) - _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) + _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 466c6a18b..ccf0be225 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -49,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall") // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule -type RouteRules []RouteRule +type RouteRules []*RouteRule func (r RouteRules) Sort() { - slices.SortStableFunc(r, func(a, b RouteRule) int { + slices.SortStableFunc(r, func(a, b *RouteRule) int { // Deny rules come first if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { return -1 @@ -99,6 +99,8 @@ type Manager struct { forwarder atomic.Pointer[forwarder.Forwarder] logger *nblog.Logger flowLogger nftypes.FlowLogger + + blockRule firewall.Rule } // decoder for packages @@ -201,41 +203,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe } } - if err := m.blockInvalidRouted(iface); err != nil { - log.Errorf("failed to block invalid routed traffic: %v", err) - } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } return m, nil } -func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { - if m.forwarder.Load() == nil { - return nil - } +func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) if err != nil { - return fmt.Errorf("parse wireguard network: %w", err) + return nil, fmt.Errorf("parse wireguard network: %w", err) } log.Debugf("blocking invalid routed traffic for %s", wgPrefix) - if _, err := m.AddRouteFiltering( + rule, err := m.addRouteFiltering( nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, - wgPrefix, + firewall.Network{Prefix: wgPrefix}, firewall.ProtocolALL, nil, nil, firewall.ActionDrop, - ); err != nil { - return fmt.Errorf("block wg nte : %w", err) + ) + if err != nil { + return nil, fmt.Errorf("block wg nte : %w", err) } // TODO: Block networks that we're a client of - return nil + return rule, nil } func (m *Manager) determineRouting() error { @@ -413,10 +409,23 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( id []byte, sources []netip.Prefix, - destination netip.Prefix, + destination firewall.Network, proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, + sPort, dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action) +} + +func (m *Manager) addRouteFiltering( + id []byte, + sources []netip.Prefix, + destination firewall.Network, + proto firewall.Protocol, + sPort, dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { if m.nativeRouter.Load() && m.nativeFirewall != nil { @@ -426,34 +435,39 @@ func (m *Manager) AddRouteFiltering( ruleID := uuid.New().String() rule := RouteRule{ // TODO: consolidate these IDs - id: ruleID, - mgmtId: id, - sources: sources, - destination: destination, - proto: proto, - srcPort: sPort, - dstPort: dPort, - action: action, + id: ruleID, + mgmtId: id, + sources: sources, + dstSet: destination.Set, + proto: proto, + srcPort: sPort, + dstPort: dPort, + action: action, + } + if destination.IsPrefix() { + rule.destinations = []netip.Prefix{destination.Prefix} } - m.mutex.Lock() - m.routeRules = append(m.routeRules, rule) + m.routeRules = append(m.routeRules, &rule) m.routeRules.Sort() - m.mutex.Unlock() return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteRouteRule(rule) +} + +func (m *Manager) deleteRouteRule(rule firewall.Rule) error { if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.DeleteRouteRule(rule) } - m.mutex.Lock() - defer m.mutex.Unlock() - ruleID := rule.ID() - idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { + idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { return r.id == ruleID }) if idx < 0 { @@ -509,6 +523,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.nativeFirewall.DeleteDNATRule(rule) } +// UpdateSet updates the rule destinations associated with the given set +// by merging the existing prefixes with the new ones, then deduplicating. +func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + if m.nativeRouter.Load() && m.nativeFirewall != nil { + return m.nativeFirewall.UpdateSet(set, prefixes) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + var matches []*RouteRule + for _, rule := range m.routeRules { + if rule.dstSet == set { + matches = append(matches, rule) + } + } + + if len(matches) == 0 { + return fmt.Errorf("no route rule found for set: %s", set) + } + + destinations := matches[0].destinations + for _, prefix := range prefixes { + if prefix.Addr().Is4() { + destinations = append(destinations, prefix) + } + } + + slices.SortFunc(destinations, func(a, b netip.Prefix) int { + cmp := a.Addr().Compare(b.Addr()) + if cmp != 0 { + return cmp + } + return a.Bits() - b.Bits() + }) + + destinations = slices.Compact(destinations) + + for _, rule := range matches { + rule.destinations = destinations + } + log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations) + + return nil +} + // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte, size int) bool { return m.processOutgoingHooks(packetData, size) @@ -988,8 +1048,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol return nil, false } -func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { - if !rule.destination.Contains(dstAddr) { +func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { + destMatched := false + for _, dst := range rule.destinations { + if dst.Contains(dstAddr) { + destMatched = true + break + } + } + if !destMatched { return false } @@ -1091,7 +1158,22 @@ func (m *Manager) EnableRouting() error { m.mutex.Lock() defer m.mutex.Unlock() - return m.determineRouting() + if err := m.determineRouting(); err != nil { + return fmt.Errorf("determine routing: %w", err) + } + + if m.forwarder.Load() == nil { + return nil + } + + rule, err := m.blockInvalidRouted(m.wgIface) + if err != nil { + return fmt.Errorf("block invalid routed: %w", err) + } + + m.blockRule = rule + + return nil } func (m *Manager) DisableRouting() error { @@ -1116,5 +1198,12 @@ func (m *Manager) DisableRouting() error { log.Debug("forwarder stopped") + if m.blockRule != nil { + if err := m.deleteRouteRule(m.blockRule); err != nil { + return fmt.Errorf("delete block rule: %w", err) + } + m.blockRule = nil + } + return nil } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 9c0a54e3f..04a398d1f 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/management/domain" ) func TestPeerACLFiltering(t *testing.T) { @@ -600,8 +601,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { } manager, err := Create(ifaceMock, false, flowLogger) - require.NoError(tb, manager.EnableRouting()) require.NoError(tb, err) + require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) require.True(tb, manager.routingEnabled.Load()) require.False(tb, manager.nativeRouter.Load()) @@ -618,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) { type rule struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -644,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -660,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -676,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionAccept, @@ -692,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 53, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{Values: []uint16{53}}, action: fw.ActionAccept, @@ -706,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolICMP, action: fw.ActionAccept, }, @@ -721,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -737,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -753,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -769,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -785,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{Values: []uint16{12345}}, action: fw.ActionAccept, @@ -804,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) { netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"), }, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -818,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, action: fw.ActionAccept, }, @@ -833,7 +834,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -849,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, action: fw.ActionAccept, @@ -865,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, action: fw.ActionAccept, @@ -881,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, srcPort: &fw.Port{Values: []uint16{12345}}, dstPort: &fw.Port{Values: []uint16{80}}, @@ -898,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -917,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 7999, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -936,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{ IsRange: true, @@ -955,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, srcPort: &fw.Port{ IsRange: true, @@ -977,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8100, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{ IsRange: true, @@ -996,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 5060, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{ IsRange: true, @@ -1015,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 8080, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, dstPort: &fw.Port{ IsRange: true, @@ -1034,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -1050,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolALL, action: fw.ActionDrop, }, @@ -1068,13 +1069,32 @@ func TestRouteACLFiltering(t *testing.T) { netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"), }, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionDrop, }, shouldPass: false, }, + + { + name: "Drop empty destination set", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: fw.Network{Set: fw.Set{}}, + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, { name: "Accept TCP traffic outside drop port range", srcIP: "100.10.0.1", @@ -1084,7 +1104,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 7999, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, action: fw.ActionDrop, @@ -1100,7 +1120,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 443, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, action: fw.ActionAccept, }, @@ -1115,7 +1135,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 53, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, action: fw.ActionAccept, }, @@ -1130,7 +1150,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -1146,7 +1166,7 @@ func TestRouteACLFiltering(t *testing.T) { dstPort: 80, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionAccept, @@ -1160,7 +1180,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, action: fw.ActionAccept, }, @@ -1173,7 +1193,7 @@ func TestRouteACLFiltering(t *testing.T) { proto: fw.ProtocolICMP, rule: rule{ sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolUDP, action: fw.ActionAccept, }, @@ -1188,7 +1208,7 @@ func TestRouteACLFiltering(t *testing.T) { rule, err := manager.AddRouteFiltering( nil, []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - netip.MustParsePrefix("0.0.0.0/0"), + fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, fw.ProtocolALL, nil, nil, @@ -1235,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) { name string rules []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -1256,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) { name: "Drop rules take precedence over accept", rules: []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -1265,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) { { // Accept rule added first sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80, 443}}, action: fw.ActionAccept, @@ -1273,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) { { // Drop rule added second but should be evaluated first sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -1311,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) { name: "Multiple drop rules take precedence", rules: []struct { sources []netip.Prefix - dest netip.Prefix + dest fw.Network proto fw.Protocol srcPort *fw.Port dstPort *fw.Port @@ -1320,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) { { // Accept all sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, - dest: netip.MustParsePrefix("0.0.0.0/0"), + dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")}, proto: fw.ProtocolALL, action: fw.ActionAccept, }, { // Drop specific port sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{443}}, action: fw.ActionDrop, @@ -1335,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) { { // Drop different port sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, - dest: netip.MustParsePrefix("192.168.1.0/24"), + dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, proto: fw.ProtocolTCP, dstPort: &fw.Port{Values: []uint16{80}}, action: fw.ActionDrop, @@ -1414,3 +1434,53 @@ func TestRouteACLOrder(t *testing.T) { }) } } + +func TestRouteACLSet(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + // Add rule that uses the set (initially empty) + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP := netip.MustParseAddr("192.168.1.100") + + // Check that traffic is dropped (empty set shouldn't match anything) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + require.False(t, isAllowed, "Empty set should not allow any traffic") + + err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) + require.NoError(t, err) + + // Now the packet should be allowed + _, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index a48a483f8..24a6a2c40 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" + "github.com/netbirdio/netbird/management/domain" ) var logger = log.NewFromLogrus(logrus.StandardLogger()) @@ -711,3 +712,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { }) } } + +func TestUpdateSetMerge(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + initialPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + // Update the set with initial prefixes + err = manager.UpdateSet(set, initialPrefixes) + require.NoError(t, err) + + // Test initial prefixes work + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP1 := netip.MustParseAddr("10.0.0.100") + dstIP2 := netip.MustParseAddr("192.168.1.100") + dstIP3 := netip.MustParseAddr("172.16.0.100") + + _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) + _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) + + require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") + require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") + require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied") + + newPrefixes := []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.1.0.0/24"), + } + + err = manager.UpdateSet(set, newPrefixes) + require.NoError(t, err) + + // Check that all original prefixes are still included + _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) + _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") + require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") + + // Check that new prefixes are included + dstIP4 := netip.MustParseAddr("172.16.1.100") + dstIP5 := netip.MustParseAddr("10.1.0.50") + + _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) + _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) + + require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") + require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") + + // Verify the rule has all prefixes + manager.mutex.RLock() + foundRule := false + for _, r := range manager.routeRules { + if r.id == rule.ID() { + foundRule = true + require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes), + "Rule should have all prefixes merged") + } + } + manager.mutex.RUnlock() + require.True(t, foundRule, "Rule should be found") +} + +func TestUpdateSetDeduplication(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + manager, err := Create(ifaceMock, false, flowLogger) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + set := fw.NewDomainSet(domain.List{"example.org"}) + + rule, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + fw.Network{Set: set}, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + initialPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), // Duplicate + } + + err = manager.UpdateSet(set, initialPrefixes) + require.NoError(t, err) + + // Check the internal state for deduplication + manager.mutex.RLock() + foundRule := false + for _, r := range manager.routeRules { + if r.id == rule.ID() { + foundRule = true + // Should have deduplicated to 2 prefixes + require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed") + + // Check the prefixes are correct + expectedPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + } + for i, prefix := range expectedPrefixes { + require.True(t, r.destinations[i] == prefix, + "Prefix should match expected value") + } + } + } + manager.mutex.RUnlock() + require.True(t, foundRule, "Rule should be found") + + // Test with overlapping prefixes of different sizes + overlappingPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/16"), // More general + netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists) + netip.MustParsePrefix("192.168.0.0/16"), // More general + netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists) + } + + err = manager.UpdateSet(set, overlappingPrefixes) + require.NoError(t, err) + + // Check that all prefixes are included (no deduplication of overlapping prefixes) + manager.mutex.RLock() + for _, r := range manager.routeRules { + if r.id == rule.ID() { + // Should have all 4 prefixes (2 original + 2 new more general ones) + require.Len(t, r.destinations, 4, + "Overlapping prefixes should not be deduplicated") + + // Verify they're sorted correctly (more specific prefixes should come first) + prefixes := make([]string, 0, len(r.destinations)) + for _, p := range r.destinations { + prefixes = append(prefixes, p.String()) + } + + // Check sorted order + require.Equal(t, []string{ + "10.0.0.0/16", + "10.0.0.0/24", + "192.168.0.0/16", + "192.168.1.0/24", + }, prefixes, "Prefixes should be sorted") + } + } + manager.mutex.RUnlock() + + // Test functionality with all prefixes + testCases := []struct { + dstIP netip.Addr + expected bool + desc string + }{ + {netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"}, + {netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"}, + {netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"}, + {netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"}, + {netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"}, + } + + srcIP := netip.MustParseAddr("100.10.0.1") + for _, tc := range testCases { + _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) + require.Equal(t, tc.expected, isAllowed, tc.desc) + } +} diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index 93f16b429..23451453e 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -18,7 +18,7 @@ func (r RuleID) ID() string { func GenerateRouteRuleKey( sources []netip.Prefix, - destination netip.Prefix, + destination manager.Network, proto manager.Protocol, sPort *manager.Port, dPort *manager.Port, diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 61fbb10ca..6fa35d5c2 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -18,6 +18,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty") // Manager is a ACL rules manager type Manager interface { - ApplyFiltering(networkMap *mgmProto.NetworkMap) + ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } type protoMatch struct { @@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager { // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // // If allowByDefault is true it appends allow ALL traffic rules to input and output chains. -func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { +func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) { d.mutex.Lock() defer d.mutex.Unlock() @@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { log.Errorf("failed to set legacy management flag: %v", err) } - if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { log.Errorf("Failed to apply route ACLs: %v", err) } @@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { d.peerRulesPairs = newRulePairs } -func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { +func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error { newRouteRules := make(map[id.RuleID]struct{}, len(rules)) var merr *multierror.Error // Apply new rules - firewall manager will return existing rule ID if already present for _, rule := range rules { - id, err := d.applyRouteACL(rule) + id, err := d.applyRouteACL(rule, dynamicResolver) if err != nil { if errors.Is(err, ErrSourceRangesEmpty) { - log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) + log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err) } else { merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) } @@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err return nberrors.FormatErrorOrNil(merr) } -func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) { if len(rule.SourceRanges) == 0 { return "", ErrSourceRangesEmpty } @@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul sources = append(sources, source) } - var destination netip.Prefix - if rule.IsDynamic { - destination = getDefault(sources[0]) - } else { - var err error - destination, err = netip.ParsePrefix(rule.Destination) - if err != nil { - return "", fmt.Errorf("parse destination: %w", err) - } + destination, err := determineDestination(rule, dynamicResolver, sources) + if err != nil { + return "", fmt.Errorf("determine destination: %w", err) } protocol, err := convertToFirewallProtocol(rule.Protocol) @@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { return nil } +func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) { + var destination firewall.Network + + if rule.IsDynamic { + if dynamicResolver { + if len(rule.Domains) > 0 { + destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains)) + } else { + // isDynamic is set but no domains = outdated management server + log.Warn("connected to an older version of management server (no domains in rules), using default destination") + destination.Prefix = getDefault(sources[0]) + } + } else { + // client resolves DNS, we (router) don't know the destination + destination.Prefix = getDefault(sources[0]) + } + return destination, nil + } + + prefix, err := netip.ParsePrefix(rule.Destination) + if err != nil { + return destination, fmt.Errorf("parse destination: %w", err) + } + destination.Prefix = prefix + return destination, nil +} + func getDefault(prefix netip.Prefix) netip.Prefix { if prefix.Addr().Is6() { return netip.PrefixFrom(netip.IPv6Unspecified(), 0) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 9488d33ab..3595ca600 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -66,7 +66,7 @@ func TestDefaultManager(t *testing.T) { acl := NewDefaultManager(fw) t.Run("apply firewall rules", func(t *testing.T) { - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) @@ -92,7 +92,7 @@ func TestDefaultManager(t *testing.T) { }, ) - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) // we should have one old and one new rule in the existed rules if len(acl.peerRulesPairs) != 2 { @@ -116,13 +116,13 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 { t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 1 { t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return @@ -359,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }(fw) acl := NewDefaultManager(fw) - acl.ApplyFiltering(networkMap) + acl.ApplyFiltering(networkMap, false) if len(acl.peerRulesPairs) != 3 { t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index 291531fea..b4907beca 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -59,6 +59,16 @@ func collectIPTablesRules() (string, error) { builder.WriteString("\n") } + // Collect ipset information + ipsetOutput, err := collectIPSets() + if err != nil { + log.Warnf("Failed to collect ipset information: %v", err) + } else { + builder.WriteString("=== ipset list output ===\n") + builder.WriteString(ipsetOutput) + builder.WriteString("\n") + } + builder.WriteString("=== iptables -v -n -L output ===\n") tables := []string{"filter", "nat", "mangle", "raw", "security"} @@ -78,6 +88,28 @@ func collectIPTablesRules() (string, error) { return builder.String(), nil } +// collectIPSets collects information about ipsets +func collectIPSets() (string, error) { + cmd := exec.Command("ipset", "list") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + if strings.Contains(err.Error(), "executable file not found") { + return "", fmt.Errorf("ipset command not found: %w", err) + } + return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String()) + } + + ipsets := stdout.String() + if strings.TrimSpace(ipsets) == "" { + return "No ipsets found", nil + } + + return ipsets, nil +} + // collectIPTablesSave uses iptables-save to get rule definitions func collectIPTablesSave() (string, error) { cmd := exec.Command("iptables-save") diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 2d69ce858..8f6a31f47 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -3,6 +3,7 @@ package dnsfwd import ( "context" "errors" + "fmt" "math" "net" "net/netip" @@ -10,11 +11,16 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) const errResolveFailed = "failed to resolve query for domain=%s: %v" @@ -23,25 +29,27 @@ const upstreamTimeout = 15 * time.Second type DNSForwarder struct { listenAddress string ttl uint32 - domains []string statusRecorder *peer.Status dnsServer *dns.Server mux *dns.ServeMux - resId sync.Map + mutex sync.RWMutex + fwdEntries []*ForwarderEntry + firewall firewall.Manager } -func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, + firewall: firewall, statusRecorder: statusRecorder, } } -func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error { +func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { log.Infof("listen DNS forwarder on address=%s", f.listenAddress) mux := dns.NewServeMux() @@ -53,31 +61,35 @@ func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error f.dnsServer = dnsServer f.mux = mux - f.UpdateDomains(domains, resIds) + f.UpdateDomains(entries) return dnsServer.ListenAndServe() } -func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) { - log.Debugf("Updating domains from %v to %v", f.domains, domains) +func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { + f.mutex.Lock() + defer f.mutex.Unlock() - for _, d := range f.domains { - f.mux.HandleRemove(d) + if f.mux == nil { + log.Debug("DNS mux is nil, skipping domain update") + f.fwdEntries = entries + return } - f.resId.Clear() - newDomains := filterDomains(domains) + oldDomains := filterDomains(f.fwdEntries) + + for _, d := range oldDomains { + f.mux.HandleRemove(d.PunycodeString()) + } + + newDomains := filterDomains(entries) for _, d := range newDomains { - f.mux.HandleFunc(d, f.handleDNSQuery) + f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery) } - for domain, resId := range resIds { - if domain != "" { - f.resId.Store(domain, resId) - } - } + f.fwdEntries = entries - f.domains = newDomains + log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) } func (f *DNSForwarder) Close(ctx context.Context) error { @@ -91,11 +103,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { if len(query.Question) == 0 { return } - log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", - query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) - question := query.Question[0] - domain := question.Name + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + question.Name, question.Qtype, question.Qclass) + + domain := strings.ToLower(question.Name) resp := query.SetReply(query) var network string @@ -122,21 +134,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { return } - resId := f.getResIdForDomain(strings.TrimSuffix(domain, ".")) - if resId != "" { - for _, ip := range ips { - var ipWithSuffix string - if ip.Is4() { - ipWithSuffix = ip.String() + "/32" - log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix) - } else { - ipWithSuffix = ip.String() + "/128" - log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix) - } - f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId) - } - } - + f.updateInternalState(domain, ips) f.addIPsToResponse(resp, domain, ips) if err := w.WriteMsg(resp); err != nil { @@ -144,6 +142,42 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { } } +func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { + var prefixes []netip.Prefix + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) + if mostSpecificResId != "" { + for _, ip := range ips { + var prefix netip.Prefix + if ip.Is4() { + prefix = netip.PrefixFrom(ip, 32) + } else { + prefix = netip.PrefixFrom(ip, 128) + } + prefixes = append(prefixes, prefix) + f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId) + } + } + + if f.firewall != nil { + f.updateFirewall(matchingEntries, prefixes) + } +} + +func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) { + var merr *multierror.Error + for _, entry := range matchingEntries { + if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil { + merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err)) + } + } + if merr != nil { + log.Errorf("failed to update firewall sets (%d/%d): %v", + len(merr.Errors), + len(matchingEntries), + nberrors.FormatErrorOrNil(merr)) + } +} + // handleDNSError processes DNS lookup errors and sends an appropriate error response func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { var dnsErr *net.DNSError @@ -204,45 +238,53 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti } } -func (f *DNSForwarder) getResIdForDomain(domain string) string { - var selectedResId string +// getMatchingEntries retrieves the resource IDs for a given domain. +// It returns the most specific match and all matching resource IDs. +func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) { + var selectedResId route.ResID var bestScore int + var matches []*ForwarderEntry - f.resId.Range(func(key, value interface{}) bool { + f.mutex.RLock() + defer f.mutex.RUnlock() + + for _, entry := range f.fwdEntries { var score int - pattern := key.(string) + pattern := entry.Domain.PunycodeString() switch { case strings.HasPrefix(pattern, "*."): baseDomain := strings.TrimPrefix(pattern, "*.") - if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + + if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) { score = len(baseDomain) + matches = append(matches, entry) } case domain == pattern: score = math.MaxInt + matches = append(matches, entry) default: - return true + continue } if score > bestScore { bestScore = score - selectedResId = value.(string) + selectedResId = entry.ResID } - return true - }) + } - return selectedResId + return selectedResId, matches } // filterDomains returns a list of normalized domains -func filterDomains(domains []string) []string { - newDomains := make([]string, 0, len(domains)) - for _, d := range domains { - if d == "" { +func filterDomains(entries []*ForwarderEntry) domain.List { + newDomains := make(domain.List, 0, len(entries)) + for _, d := range entries { + if d.Domain == "" { log.Warn("empty domain in DNS forwarder") continue } - newDomains = append(newDomains, nbdns.NormalizeZone(d)) + newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString()))) } return newDomains } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 88ffc2af3..f0829bbbd 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -1,56 +1,61 @@ package dnsfwd import ( - "sync" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) -func TestGetResIdForDomain(t *testing.T) { +func Test_getMatchingEntries(t *testing.T) { testCases := []struct { name string - storedMappings map[string]string // key: domain pattern, value: resId + storedMappings map[string]route.ResID // key: domain pattern, value: resId queryDomain string - expectedResId string + expectedResId route.ResID }{ { name: "Empty map returns empty string", - storedMappings: map[string]string{}, + storedMappings: map[string]route.ResID{}, queryDomain: "example.com", expectedResId: "", }, { name: "Exact match returns stored resId", - storedMappings: map[string]string{"example.com": "res1"}, + storedMappings: map[string]route.ResID{"example.com": "res1"}, queryDomain: "example.com", expectedResId: "res1", }, { name: "Wildcard pattern matches base domain", - storedMappings: map[string]string{"*.example.com": "res2"}, + storedMappings: map[string]route.ResID{"*.example.com": "res2"}, queryDomain: "example.com", expectedResId: "res2", }, { name: "Wildcard pattern matches subdomain", - storedMappings: map[string]string{"*.example.com": "res3"}, + storedMappings: map[string]route.ResID{"*.example.com": "res3"}, queryDomain: "foo.example.com", expectedResId: "res3", }, { name: "Wildcard pattern does not match different domain", - storedMappings: map[string]string{"*.example.com": "res4"}, + storedMappings: map[string]route.ResID{"*.example.com": "res4"}, queryDomain: "foo.notexample.com", expectedResId: "", }, { name: "Non-wildcard pattern does not match subdomain", - storedMappings: map[string]string{"example.com": "res5"}, + storedMappings: map[string]route.ResID{"example.com": "res5"}, queryDomain: "foo.example.com", expectedResId: "", }, { name: "Exact match over overlapping wildcard", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resWildcard", "foo.example.com": "resExact", }, @@ -59,7 +64,7 @@ func TestGetResIdForDomain(t *testing.T) { }, { name: "Overlapping wildcards: Select more specific wildcard", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resA", "*.sub.example.com": "resB", }, @@ -68,7 +73,7 @@ func TestGetResIdForDomain(t *testing.T) { }, { name: "Wildcard multi-level subdomain match", - storedMappings: map[string]string{ + storedMappings: map[string]route.ResID{ "*.example.com": "resMulti", }, queryDomain: "a.b.example.com", @@ -78,18 +83,21 @@ func TestGetResIdForDomain(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - fwd := &DNSForwarder{ - resId: sync.Map{}, - } + fwd := &DNSForwarder{} + var entries []*ForwarderEntry for domainPattern, resId := range tc.storedMappings { - fwd.resId.Store(domainPattern, resId) + d, err := domain.FromString(domainPattern) + require.NoError(t, err) + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: resId, + }) } + fwd.UpdateDomains(entries) - got := fwd.getResIdForDomain(tc.queryDomain) - if got != tc.expectedResId { - t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got) - } + got, _ := fwd.getMatchingEntries(tc.queryDomain) + assert.Equal(t, got, tc.expectedResId) }) } } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a51ae7abb..e4a23450f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -11,6 +11,8 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" ) const ( @@ -19,6 +21,13 @@ const ( dnsTTL = 60 //seconds ) +// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. +type ForwarderEntry struct { + Domain domain.Domain + ResID route.ResID + Set firewall.Set +} + type Manager struct { firewall firewall.Manager statusRecorder *peer.Status @@ -34,7 +43,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { } } -func (m *Manager) Start(domains []string, resIds map[string]string) error { +func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { log.Infof("starting DNS forwarder") if m.dnsForwarder != nil { return nil @@ -44,9 +53,9 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder) + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) go func() { - if err := m.dnsForwarder.Listen(domains, resIds); err != nil { + if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists log.Errorf("failed to start DNS forwarder, err: %v", err) } @@ -55,12 +64,12 @@ func (m *Manager) Start(domains []string, resIds map[string]string) error { return nil } -func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) { +func (m *Manager) UpdateDomains(entries []*ForwarderEntry) { if m.dnsForwarder == nil { return } - m.dnsForwarder.UpdateDomains(domains, resIds) + m.dnsForwarder.UpdateDomains(entries) } func (m *Manager) Stop(ctx context.Context) error { @@ -81,34 +90,34 @@ func (m *Manager) Stop(ctx context.Context) error { return nberrors.FormatErrorOrNil(mErr) } -func (h *Manager) allowDNSFirewall() error { +func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, Values: []uint16{ListenPort}, } - if h.firewall == nil { + if m.firewall == nil { return nil } - dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") + dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err } - h.fwRules = dnsRules + m.fwRules = dnsRules return nil } -func (h *Manager) dropDNSFirewall() error { +func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error - for _, rule := range h.fwRules { - if err := h.firewall.DeletePeerRule(rule); err != nil { + for _, rule := range m.fwRules { + if err := m.firewall.DeletePeerRule(rule); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) } } - h.fwRules = nil + m.fwRules = nil return nberrors.FormatErrorOrNil(mErr) } diff --git a/client/internal/engine.go b/client/internal/engine.go index c377c12e1..b16232883 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() { if _, err := e.firewall.AddRouteFiltering( nil, []netip.Prefix{v4}, - network, + firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, nil, @@ -960,21 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } - // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) - dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) - e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - // acls might need routing to be enabled, so we apply after routes if e.acl != nil { - e.acl.ApplyFiltering(networkMap) + e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag) } + fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) + // Ingress forward rules if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil { log.Errorf("failed to update forward rules, err: %v", err) @@ -1079,29 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { return routes } -func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) { - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} - } - - var dnsRoutes []string - resIds := make(map[string]string) - for _, protoRoute := range protoRoutes { - if len(protoRoute.Domains) == 0 { +func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry { + var entries []*dnsfwd.ForwarderEntry + for _, route := range routes { + if len(route.Domains) == 0 { continue } - if protoRoute.Peer == myPubKey { - dnsRoutes = append(dnsRoutes, protoRoute.Domains...) - // resource ID is the first part of the ID - resId := strings.Split(protoRoute.ID, ":") - for _, domain := range protoRoute.Domains { - if len(resId) > 0 { - resIds[domain] = resId[0] - } + if route.Peer == myPubKey { + domainSet := firewallManager.NewDomainSet(route.Domains) + for _, d := range route.Domains { + entries = append(entries, &dnsfwd.ForwarderEntry{ + Domain: d, + Set: domainSet, + ResID: route.GetResourceID(), + }) } } } - return dnsRoutes, resIds + return entries } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { @@ -1751,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP { } // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag -func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) { +func (e *Engine) updateDNSForwarder( + enabled bool, + fwdEntries []*dnsfwd.ForwarderEntry, +) { if !enabled { if e.dnsForwardMgr == nil { return @@ -1762,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[s return } - if len(domains) > 0 { - log.Infof("enable domain router service for domains: %v", domains) + if len(fwdEntries) > 0 { if e.dnsForwardMgr == nil { e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(domains, resIds); err != nil { + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) } else { - log.Infof("update domain router service for domains: %v", domains) - e.dnsForwardMgr.UpdateDomains(domains, resIds) + e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") diff --git a/client/internal/peer/route.go b/client/internal/peer/route.go index c3567dcc9..e5e315e3c 100644 --- a/client/internal/peer/route.go +++ b/client/internal/peer/route.go @@ -6,12 +6,14 @@ import ( "sync" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/route" ) // routeEntry holds the route prefix and the corresponding resource ID. type routeEntry struct { prefix netip.Prefix - resourceID string + resourceID route.ResID } type routeIDLookup struct { @@ -24,7 +26,7 @@ type routeIDLookup struct { resolvedIPs sync.Map } -func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) { r.localLock.Lock() defer r.localLock.Unlock() @@ -56,7 +58,7 @@ func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) { } } -func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) { r.remoteLock.Lock() defer r.remoteLock.Unlock() @@ -87,7 +89,7 @@ func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) { } } -func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) { +func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) { r.resolvedIPs.Store(route.Addr(), resourceID) } @@ -97,19 +99,19 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) { // Lookup returns the resource ID for the given IP address // and a bool indicating if the IP is an exit node. -func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { +func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) { if res, ok := r.resolvedIPs.Load(ip); ok { - return res.(string), false + return res.(route.ResID), false } - var resourceID string + var resourceID route.ResID var isExitNode bool r.localLock.RLock() for _, entry := range r.localRoutes { if entry.prefix.Contains(ip) { resourceID = entry.resourceID - isExitNode = (entry.prefix.Bits() == 0) + isExitNode = entry.prefix.Bits() == 0 break } } @@ -120,7 +122,7 @@ func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) { for _, entry := range r.remoteRoutes { if entry.prefix.Contains(ip) { resourceID = entry.resourceID - isExitNode = (entry.prefix.Bits() == 0) + isExitNode = entry.prefix.Bits() == 0 break } } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 9b3fc744d..3eca6a8c9 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/domain" relayClient "github.com/netbirdio/netbird/relay/client" + "github.com/netbirdio/netbird/route" ) const eventQueueSize = 10 @@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } -func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error { +func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { d.mux.Lock() defer d.mux.Unlock() @@ -581,7 +582,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { } // AddLocalPeerStateRoute adds a route to the local peer state -func (d *Status) AddLocalPeerStateRoute(route, resourceId string) { +func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() @@ -611,14 +612,11 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) { } // AddResolvedIPLookupEntry adds a resolved IP lookup entry -func (d *Status) AddResolvedIPLookupEntry(route, resourceId string) { +func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() - pref, err := netip.ParsePrefix(route) - if err == nil { - d.routeIDLookup.AddResolvedIP(resourceId, pref) - } + d.routeIDLookup.AddResolvedIP(resourceId, prefix) } // RemoveResolvedIPLookupEntry removes a resolved IP lookup entry @@ -723,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) { d.mux.Lock() defer d.mux.Unlock() diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 68d81d968..6d51c88c0 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -234,7 +234,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { origPattern = writer.GetOrigPattern() } - resolvedDomain := domain.Domain(r.Question[0].Name) + resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) // already punycode via RegisterHandler() originalDomain := domain.Domain(origPattern) @@ -328,6 +328,11 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom // Update domain prefixes using resolved domain as key if len(toAdd) > 0 || len(toRemove) > 0 { + if d.route.KeepRoute { + // replace stored prefixes with old + added + // nolint:gocritic + newPrefixes = append(oldPrefixes, toAdd...) + } d.interceptedDomains[resolvedDomain] = newPrefixes originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) @@ -338,7 +343,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom originalDomain.SafeString(), toAdd) } - if len(toRemove) > 0 { + if len(toRemove) > 0 && !d.route.KeepRoute { log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", resolvedDomain.SafeString(), originalDomain.SafeString(), diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index ae0d1d220..078206ab9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -259,8 +259,6 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } } - m.ctx = nil - m.mux.Lock() defer m.mux.Unlock() m.clientRoutes = nil @@ -292,7 +290,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro return nil } - if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { + if err := m.serverRouter.updateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { return fmt.Errorf("update routes: %w", err) } diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 48bb0380d..953210e9e 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -18,7 +18,7 @@ type serverRouter struct { func (r serverRouter) cleanUp() { } -func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { +func (r serverRouter) updateRoutes(map[route.ID]*route.Route, bool) error { return nil } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 18713ee65..131d4c170 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -35,7 +35,10 @@ func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall fi }, nil } -func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { +func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRoute bool) error { + m.mux.Lock() + defer m.mux.Unlock() + serverRoutesToRemove := make([]route.ID, 0) for routeID := range m.routes { @@ -73,7 +76,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { continue } - err := m.addToServerNetwork(newRoute) + err := m.addToServerNetwork(newRoute, useNewDNSRoute) if err != nil { log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err) continue @@ -90,57 +93,30 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { return m.ctx.Err() } - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { + routerPair := routeToRouterPair(route, false) + if err := m.firewall.RemoveNatRule(routerPair); err != nil { return fmt.Errorf("remove routing rules: %w", err) } delete(m.routes, route.ID) - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - m.statusRecorder.RemoveLocalPeerStateRoute(routeStr) + m.statusRecorder.RemoveLocalPeerStateRoute(route.NetString()) return nil } -func (m *serverRouter) addToServerNetwork(route *route.Route) error { +func (m *serverRouter) addToServerNetwork(route *route.Route, useNewDNSRoute bool) error { if m.ctx.Err() != nil { log.Infof("Not adding to server network because context is done") return m.ctx.Err() } - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.AddNatRule(routerPair) - if err != nil { + routerPair := routeToRouterPair(route, useNewDNSRoute) + if err := m.firewall.AddNatRule(routerPair); err != nil { return fmt.Errorf("insert routing rules: %w", err) } m.routes[route.ID] = route - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - - m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID()) + m.statusRecorder.AddLocalPeerStateRoute(route.NetString(), route.GetResourceID()) return nil } @@ -148,31 +124,29 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error { func (m *serverRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() - for _, r := range m.routes { - routerPair, err := routeToRouterPair(r) - if err != nil { - log.Errorf("Failed to convert route to router pair: %v", err) - continue - } - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { + for _, r := range m.routes { + routerPair := routeToRouterPair(r, false) + if err := m.firewall.RemoveNatRule(routerPair); err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } - } m.statusRecorder.CleanLocalPeerStateRoutes() } -func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { - // TODO: add ipv6 +func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterPair { source := getDefaultPrefix(route.Network) - - destination := route.Network.Masked() + destination := firewall.Network{} if route.IsDynamic() { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination) + if useNewDNSRoute { + destination.Set = firewall.NewDomainSet(route.Domains) + } else { + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination.Prefix) + } + } else { + destination.Prefix = route.Network.Masked() } return firewall.RouterPair{ @@ -180,12 +154,16 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { Source: source, Destination: destination, Masquerade: route.Masquerade, - }, nil + } } -func getDefaultPrefix(prefix netip.Prefix) netip.Prefix { +func getDefaultPrefix(prefix netip.Prefix) firewall.Network { if prefix.Addr().Is6() { - return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + return firewall.Network{ + Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + } + } + return firewall.Network{ + Prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0), } - return netip.PrefixFrom(netip.IPv4Unspecified(), 0) } diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index cf3c2f0aa..59b6346c6 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -45,7 +45,7 @@ var sysctlFailed bool type ruleParams struct { priority int - fwmark int + fwmark uint32 tableID int family int invert bool @@ -55,8 +55,8 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, - {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, + {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, + {100, 0, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, } diff --git a/client/server/network.go b/client/server/network.go index e0b01f763..93b7caa46 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -100,7 +100,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro // Convert to proto format for domain, ips := range domainMap { - pbRoute.ResolvedIPs[domain.PunycodeString()] = &proto.IPList{ + pbRoute.ResolvedIPs[domain.SafeString()] = &proto.IPList{ Ips: ips, } } diff --git a/client/status/status.go b/client/status/status.go index 43acc9197..f37e5b0f0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/version" ) @@ -414,7 +415,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, signalConnString, relaysString, dnsServersString, - overview.FQDN, + domain.Domain(overview.FQDN).SafeString(), interfaceIP, interfaceTypeString, rosenpassEnabledStatus, @@ -508,7 +509,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Quantum resistance: %s\n"+ " Networks: %s\n"+ " Latency: %s\n", - peerState.FQDN, + domain.Domain(peerState.FQDN).SafeString(), peerState.IP, peerState.PubKey, peerState.Status, diff --git a/dns/dns.go b/dns/dns.go index 8dfdf8526..3a1c76e56 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -111,6 +111,5 @@ func GetParsedDomainLabel(name string) (string, error) { // NormalizeZone returns a normalized domain name without the wildcard prefix func NormalizeZone(domain string) string { - d, _ := strings.CutPrefix(domain, "*.") - return d + return strings.TrimPrefix(domain, "*.") } diff --git a/go.mod b/go.mod index b1b01d446..095840f13 100644 --- a/go.mod +++ b/go.mod @@ -18,9 +18,9 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.36.0 - golang.org/x/sys v0.31.0 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/crypto v0.37.0 + golang.org/x/sys v0.32.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -39,7 +39,6 @@ require ( github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 - github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2 @@ -49,7 +48,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.2.0 + github.com/google/nftables v0.3.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -100,10 +99,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.38.0 + golang.org/x/net v0.39.0 golang.org/x/oauth2 v0.24.0 - golang.org/x/sync v0.12.0 - golang.org/x/term v0.30.0 + golang.org/x/sync v0.13.0 + golang.org/x/term v0.31.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -145,6 +144,7 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v26.1.5+incompatible // indirect @@ -183,7 +183,6 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/josharian/native v1.1.0 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect @@ -192,7 +191,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect - github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -235,7 +234,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index fb351dd25..8c1c021f8 100644 --- a/go.sum +++ b/go.sum @@ -301,8 +301,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= -github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= +github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= +github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -399,8 +399,6 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= -github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -447,8 +445,8 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= @@ -665,9 +663,8 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= @@ -752,8 +749,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -846,8 +843,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -876,8 +873,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -902,7 +899,6 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -911,7 +907,6 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -939,14 +934,16 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -954,8 +951,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -969,8 +966,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/management/domain/domain.go b/management/domain/domain.go index 2e089b01f..97acec688 100644 --- a/management/domain/domain.go +++ b/management/domain/domain.go @@ -1,12 +1,17 @@ package domain import ( + "strings" + "golang.org/x/net/idna" ) +// Domain represents a punycode-encoded domain string. +// This should only be converted from a string when the string already is in punycode, otherwise use FromString. type Domain string // String converts the Domain to a non-punycode string. +// For an infallible conversion, use SafeString. func (d Domain) String() (string, error) { unicode, err := idna.ToUnicode(string(d)) if err != nil { @@ -15,16 +20,17 @@ func (d Domain) String() (string, error) { return unicode, nil } -// SafeString converts the Domain to a non-punycode string, falling back to the original string if conversion fails. +// SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails. func (d Domain) SafeString() string { str, err := d.String() if err != nil { - str = string(d) + return string(d) } return str } // PunycodeString returns the punycode representation of the Domain. +// This should only be used if a punycode domain is expected but only a string is supported. func (d Domain) PunycodeString() string { return string(d) } @@ -35,5 +41,5 @@ func FromString(s string) (Domain, error) { if err != nil { return "", err } - return Domain(ascii), nil + return Domain(strings.ToLower(ascii)), nil } diff --git a/management/domain/list.go b/management/domain/list.go index b6090c717..a988f4f70 100644 --- a/management/domain/list.go +++ b/management/domain/list.go @@ -5,6 +5,7 @@ import ( "strings" ) +// List is a slice of punycode-encoded domain strings. type List []Domain // ToStringList converts a List to a slice of string. @@ -53,7 +54,7 @@ func (d List) String() (string, error) { func (d List) SafeString() string { str, err := d.String() if err != nil { - return strings.Join(d.ToPunycodeList(), ", ") + return d.PunycodeString() } return str } @@ -101,7 +102,7 @@ func FromStringList(s []string) (List, error) { func FromPunycodeList(s []string) List { var dl List for _, domain := range s { - dl = append(dl, Domain(domain)) + dl = append(dl, Domain(strings.ToLower(domain))) } return dl } diff --git a/management/domain/validate.go b/management/domain/validate.go index bcbf26e05..a42aebe6f 100644 --- a/management/domain/validate.go +++ b/management/domain/validate.go @@ -22,8 +22,6 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - d := strings.ToLower(d) - // handles length and idna conversion punycode, err := FromString(d) if err != nil { diff --git a/management/server/types/account.go b/management/server/types/account.go index e9fa37085..8315f5796 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1289,7 +1289,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer if route.Peer != peer.Key { continue } - resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] + resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())] distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) diff --git a/route/hauniqueid.go b/route/hauniqueid.go index 4d952beba..064608171 100644 --- a/route/hauniqueid.go +++ b/route/hauniqueid.go @@ -4,13 +4,14 @@ import "strings" const haSeparator = "|" +// HAUniqueID is a unique identifier that is used to group high availability routes. type HAUniqueID string func (id HAUniqueID) String() string { return string(id) } -// NetID returns the Network ID from the HAUniqueID +// NetID returns the NetID from the HAUniqueID func (id HAUniqueID) NetID() NetID { if i := strings.LastIndex(string(id), haSeparator); i != -1 { return NetID(id[:i]) diff --git a/route/route.go b/route/route.go index f7bf3ea87..722dacc2d 100644 --- a/route/route.go +++ b/route/route.go @@ -6,8 +6,6 @@ import ( "slices" "strings" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/status" ) @@ -46,10 +44,16 @@ const ( DomainNetwork ) +// ID is the unique route ID. type ID string +// ResID is the resourceID part of a route.ID (first part before the colon). +type ResID string + +// NetID is the route network identifier, a human-readable string. type NetID string +// HAMap is a map of HAUniqueID to a list of routes. type HAMap map[HAUniqueID][]*Route // NetworkType route network type @@ -162,21 +166,25 @@ func (r *Route) IsDynamic() bool { return r.NetworkType == DomainNetwork } +// GetHAUniqueID returns the HAUniqueID for the route, it can be used for grouping. func (r *Route) GetHAUniqueID() HAUniqueID { - if r.IsDynamic() { - domains, err := r.Domains.String() - if err != nil { - log.Errorf("Failed to convert domains to string: %v", err) - domains = r.Domains.PunycodeString() - } - return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, domains)) - } - return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) + return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.NetString())) } -// GetResourceID returns the Networks Resource ID from a route ID -func (r *Route) GetResourceID() string { - return strings.Split(string(r.ID), ":")[0] +// GetResourceID returns the Networks ResID from the route ID. +// It's the part before the first colon in the ID string. +func (r *Route) GetResourceID() ResID { + return ResID(strings.Split(string(r.ID), ":")[0]) +} + +// NetString returns the network string. +// If the route is dynamic, it returns the domains as comma-separated punycode-encoded string. +// If the route is not dynamic, it returns the network (prefix) string. +func (r *Route) NetString() string { + if r.IsDynamic() { + return r.Domains.SafeString() + } + return r.Network.String() } // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 From 2817f62c13a2bd51e332b796e23c4a468177580e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 25 Apr 2025 09:26:18 +0200 Subject: [PATCH 13/45] [client] Fix error handling case of flow grpc error (#3727) When a gRPC error occurs in the Flow package, it will be propagated to the upper layers and handled similarly to a Management gRPC error. Always report a disconnected state in the event of any error Hide the underlying gRPC errors Force close the gRPC connection in the event of any error --- management/client/grpc.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/management/client/grpc.go b/management/client/grpc.go index d3aaffec0..956aaebb2 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -128,7 +128,13 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return err } - return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) + streamErr := c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) + if c.conn.GetState() != connectivity.Shutdown { + if err := c.conn.Close(); err != nil { + log.Warnf("failed closing connection to Management service: %s", err) + } + } + return streamErr } err := backoff.Retry(operation, defaultBackoff(ctx)) @@ -159,6 +165,7 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, // blocking until error err = c.receiveEvents(stream, serverPubKey, msgHandler) if err != nil { + c.notifyDisconnected(err) s, _ := gstatus.FromError(err) switch s.Code() { case codes.PermissionDenied: @@ -167,7 +174,6 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, log.Debugf("management connection context has been canceled, this usually indicates shutdown") return nil default: - c.notifyDisconnected(err) log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) return err } @@ -258,10 +264,10 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se return err } - err = msgHandler(decryptedResp) - if err != nil { + if err := msgHandler(decryptedResp); err != nil { log.Errorf("failed handling an update message received from Management Service: %v", err.Error()) - return err + // hide any grpc error code that is not relevant for management + return fmt.Errorf("msg handler error: %v", err.Error()) } } } From ef8b8a28912e7979098b9f94294c8f23fc3aa81b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 25 Apr 2025 12:43:20 +0200 Subject: [PATCH 14/45] [client] Ensure dst-type local marks can overwrite nat marks (#3738) --- client/firewall/iptables/router_linux.go | 4 +++- client/firewall/nftables/router_linux.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index b59c88580..bb799b99b 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -631,7 +631,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), ) - if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { + // Ensure nat rules come first, so the mark can be overwritten. + // Currently overwritten by the dst-type LOCAL rules for redirected traffic. + if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil { // TODO: rollback ipset counter return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) } diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index c2ba2a072..0f6c5bdf6 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -666,7 +666,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } } - r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + // Ensure nat rules come first, so the mark can be overwritten. + // Currently overwritten by the dst-type LOCAL rules for redirected traffic. + r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainNameManglePrerouting], Exprs: exprs, From c0eaea938e3a640d22cf87ddfdc947955f8aacf6 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Fri, 25 Apr 2025 06:41:57 -0600 Subject: [PATCH 15/45] [client] Fix macos privacy warning when checking static info (#3496) avoid checking static info with a init call --- client/cmd/login.go | 3 +++ client/cmd/service_controller.go | 5 +++++ client/system/info.go | 7 +++++++ client/system/static_info.go | 6 ------ client/system/static_info_stub.go | 8 ++++++++ 5 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 client/system/static_info_stub.go diff --git a/client/cmd/login.go b/client/cmd/login.go index c86d6c636..549eef40e 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -48,6 +48,9 @@ var loginCmd = &cobra.Command{ return err } + // update host's static platform and system information + system.UpdateStaticInfo() + // workaround to run without service if logFile == "console" { err = handleRebrand(cmd) diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0ddf6c4c8..5e3c63e57 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -16,12 +16,17 @@ import ( "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/server" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/util" ) func (p *program) Start(svc service.Service) error { // Start should not block. Do the actual work async. log.Info("starting Netbird service") //nolint + + // Collect static system and platform information + system.UpdateStaticInfo() + // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. p.serv = grpc.NewServer() diff --git a/client/system/info.go b/client/system/info.go index 2a0343ca6..3a0c57156 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -185,3 +185,10 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro return info, nil } + +// UpdateStaticInfo asynchronously updates static system and platform information +func UpdateStaticInfo() { + go func() { + _ = updateStaticInfo() + }() +} diff --git a/client/system/static_info.go b/client/system/static_info.go index fabe65a68..f178ec932 100644 --- a/client/system/static_info.go +++ b/client/system/static_info.go @@ -16,12 +16,6 @@ var ( once sync.Once ) -func init() { - go func() { - _ = updateStaticInfo() - }() -} - func updateStaticInfo() StaticInfo { once.Do(func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go new file mode 100644 index 000000000..faa3e700b --- /dev/null +++ b/client/system/static_info_stub.go @@ -0,0 +1,8 @@ +//go:build android || freebsd || ios + +package system + +// updateStaticInfo returns an empty implementation for unsupported platforms +func updateStaticInfo() StaticInfo { + return StaticInfo{} +} From 39483f8ca818f38ecca9f951e67b66441a4ad20c Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:04:25 +0100 Subject: [PATCH 16/45] [management] Auditor role (#3721) --- management/server/permissions/roles/auditor.go | 16 ++++++++++++++++ .../server/permissions/roles/role_permissions.go | 7 ++++--- management/server/types/user.go | 3 +++ 3 files changed, 23 insertions(+), 3 deletions(-) create mode 100644 management/server/permissions/roles/auditor.go diff --git a/management/server/permissions/roles/auditor.go b/management/server/permissions/roles/auditor.go new file mode 100644 index 000000000..33d8651f4 --- /dev/null +++ b/management/server/permissions/roles/auditor.go @@ -0,0 +1,16 @@ +package roles + +import ( + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/types" +) + +var Auditor = RolePermissions{ + Role: types.UserRoleAuditor, + AutoAllowNew: map[operations.Operation]bool{ + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, +} diff --git a/management/server/permissions/roles/role_permissions.go b/management/server/permissions/roles/role_permissions.go index dda7e6b99..aca812fe2 100644 --- a/management/server/permissions/roles/role_permissions.go +++ b/management/server/permissions/roles/role_permissions.go @@ -15,7 +15,8 @@ type RolePermissions struct { type Permissions map[modules.Module]map[operations.Operation]bool var RolesMap = map[types.UserRole]RolePermissions{ - types.UserRoleOwner: Owner, - types.UserRoleAdmin: Admin, - types.UserRoleUser: User, + types.UserRoleOwner: Owner, + types.UserRoleAdmin: Admin, + types.UserRoleUser: User, + types.UserRoleAuditor: Auditor, } diff --git a/management/server/types/user.go b/management/server/types/user.go index 5f7a4f2cb..419e688f5 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -15,6 +15,7 @@ const ( UserRoleUser UserRole = "user" UserRoleUnknown UserRole = "unknown" UserRoleBillingAdmin UserRole = "billing_admin" + UserRoleAuditor UserRole = "auditor" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -35,6 +36,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleUser case "billing_admin": return UserRoleBillingAdmin + case "auditor": + return UserRoleAuditor default: return UserRoleUnknown } From dbf81a145e8c259cb7e7f77cd2d984955415ca03 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:14:32 +0100 Subject: [PATCH 17/45] [management] network admin role (#3720) --- .../server/permissions/roles/network_admin.go | 91 +++++++++++++++++++ .../permissions/roles/role_permissions.go | 9 +- management/server/types/user.go | 1 + 3 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 management/server/permissions/roles/network_admin.go diff --git a/management/server/permissions/roles/network_admin.go b/management/server/permissions/roles/network_admin.go new file mode 100644 index 000000000..761933386 --- /dev/null +++ b/management/server/permissions/roles/network_admin.go @@ -0,0 +1,91 @@ +package roles + +import ( + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/types" +) + +var NetworkAdmin = RolePermissions{ + Role: types.UserRoleNetworkAdmin, + AutoAllowNew: map[operations.Operation]bool{ + operations.Read: false, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + Permissions: Permissions{ + modules.Networks: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Groups: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Settings: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Accounts: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Dns: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Nameservers: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Events: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Policies: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Routes: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + modules.Users: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.SetupKeys: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, + modules.Pats: { + operations.Read: true, + operations.Create: true, + operations.Update: true, + operations.Delete: true, + }, + }, +} diff --git a/management/server/permissions/roles/role_permissions.go b/management/server/permissions/roles/role_permissions.go index aca812fe2..754e568f5 100644 --- a/management/server/permissions/roles/role_permissions.go +++ b/management/server/permissions/roles/role_permissions.go @@ -15,8 +15,9 @@ type RolePermissions struct { type Permissions map[modules.Module]map[operations.Operation]bool var RolesMap = map[types.UserRole]RolePermissions{ - types.UserRoleOwner: Owner, - types.UserRoleAdmin: Admin, - types.UserRoleUser: User, - types.UserRoleAuditor: Auditor, + types.UserRoleOwner: Owner, + types.UserRoleAdmin: Admin, + types.UserRoleUser: User, + types.UserRoleAuditor: Auditor, + types.UserRoleNetworkAdmin: NetworkAdmin, } diff --git a/management/server/types/user.go b/management/server/types/user.go index 419e688f5..e17a29bee 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -16,6 +16,7 @@ const ( UserRoleUnknown UserRole = "unknown" UserRoleBillingAdmin UserRole = "billing_admin" UserRoleAuditor UserRole = "auditor" + UserRoleNetworkAdmin UserRole = "network_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" From 38ada44a0e5f255b1e831b4f3b72de90beac4df1 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:40:54 +0200 Subject: [PATCH 18/45] [management] allow impersonation via pats (#3739) --- management/server/http/middleware/auth_middleware.go | 5 +++++ management/server/http/middleware/auth_middleware_test.go | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6f0d1556f..f2732fbf8 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -167,6 +167,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h IsPAT: true, } + if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { + userAuth.AccountId = impersonate[0] + userAuth.IsChild = ok + } + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 410ff7e15..2285ed244 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -242,14 +242,15 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { }, }, { - name: "Valid PAT Token ignores child", + name: "Valid PAT Token accesses child", path: "/test?account=xyz", authHeader: "Token " + PAT, expectedUserAuth: &nbcontext.UserAuth{ - AccountId: accountID, + AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, DomainCategory: testAccount.DomainCategory, + IsChild: true, IsPAT: true, }, }, From 4fe4c2054d78b1e0f9e448f5acb2b5f862b50b59 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 25 Apr 2025 18:25:48 +0200 Subject: [PATCH 19/45] [client] Move static check when running on foreground (#3742) --- client/cmd/login.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 549eef40e..84906a7a4 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -48,9 +48,6 @@ var loginCmd = &cobra.Command{ return err } - // update host's static platform and system information - system.UpdateStaticInfo() - // workaround to run without service if logFile == "console" { err = handleRebrand(cmd) @@ -58,6 +55,9 @@ var loginCmd = &cobra.Command{ return err } + // update host's static platform and system information + system.UpdateStaticInfo() + ic := internal.ConfigInput{ ManagementURL: managementURL, AdminURL: adminURL, From 3cf87b68467c3ebe45e86d6cb5d3abe1c6209759 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:50:44 +0200 Subject: [PATCH 20/45] [client] Run container tests more generically (#3737) --- .github/workflows/golang-test-linux.yml | 196 ++++++------------ client/Dockerfile | 5 +- client/firewall/uspfilter/tracer_test.go | 16 +- .../systemops/systemops_bsd_test.go | 1 - .../systemops/systemops_linux_test.go | 55 ++--- .../systemops/systemops_unix_test.go | 12 +- 6 files changed, 101 insertions(+), 184 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 4e690ff1b..2f1df9b1a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -146,6 +146,64 @@ jobs: - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) + test_client_on_docker: + name: "Client (Docker) / Unit" + needs: [build-cache] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + id: go-env + run: | + echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT + echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT + + - name: Cache Go modules + uses: actions/cache/restore@v4 + id: cache-restore + with: + path: | + ${{ steps.go-env.outputs.cache_dir }} + ${{ steps.go-env.outputs.modcache_dir }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Run tests in container + env: + HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }} + HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }} + run: | + CONTAINER_GOCACHE="/root/.cache/go-build" + CONTAINER_GOMODCACHE="/go/pkg/mod" + + docker run --rm \ + --cap-add=NET_ADMIN \ + --privileged \ + -v $PWD:/app \ + -w /app \ + -v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \ + -v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \ + -e CGO_ENABLED=1 \ + -e CI=true \ + -e GOARCH=${GOARCH_TARGET} \ + -e GOCACHE=${CONTAINER_GOCACHE} \ + -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ + golang:1.23-alpine \ + sh -c ' \ + apk update; apk add --no-cache \ + ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ + go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui) + ' + test_relay: name: "Relay / Unit" needs: [build-cache] @@ -179,13 +237,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -232,13 +283,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -286,13 +330,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -354,13 +391,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -449,13 +479,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -520,13 +543,6 @@ jobs: restore-keys: | ${{ runner.os }}-gotest-cache- - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install 32-bit libpcap - if: matrix.arch == '386' - run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 - - name: Install modules run: go mod tidy @@ -541,99 +557,3 @@ jobs: go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -timeout 20m ./management/... - - test_client_on_docker: - name: "Client (Docker) / Unit" - needs: [ build-cache ] - runs-on: ubuntu-22.04 - steps: - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: "1.23.x" - cache: false - - - name: Checkout code - uses: actions/checkout@v4 - - - name: Get Go environment - run: | - echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV - echo "modcache=$(go env.GOMODCACHE)" >> $GITHUB_ENV - - - name: Cache Go modules - uses: actions/cache/restore@v4 - with: - path: | - ${{ env.cache }} - ${{ env.modcache }} - key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-gotest-cache- - - - name: Install dependencies - run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev - - - name: Install modules - run: go mod tidy - - - name: Check git status - run: git --no-pager diff --exit-code - - - name: Generate Shared Sock Test bin - run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock - - - name: Generate RouteManager Test bin - run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager - - - name: Generate SystemOps Test bin (static via Alpine) - run: | - docker run --rm -v $PWD:/app -w /app \ - alpine:latest \ - sh -c " - apk add --no-cache go gcc musl-dev libpcap-dev dbus-dev && \ - adduser -D -u $(id -u) builder && \ - su builder -c '\ - cd /app && \ - CGO_ENABLED=1 GOOS=linux GOARCH=amd64 \ - go test -c -o /app/systemops-testing.bin \ - -tags netgo \ - -ldflags=\"-w -extldflags \\\"-static -ldbus-1 -lpcap\\\"\" \ - ./client/internal/routemanager/systemops \ - ' - " - - - name: Generate nftables Manager Test bin - run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... - - - name: Generate Engine Test bin - run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal - - - name: Generate Peer Test bin - run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ - - - run: chmod +x *testing.bin - - - name: Run Shared Sock tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /ci/sharedsock-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Iface tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... - - - name: Run RouteManager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /ci/routemanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run SystemOps tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /ci/systemops-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run nftables Manager tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /ci/nftablesmanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Engine tests in docker with file store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Engine tests in docker with sqlite store - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 - - - name: Run Peer tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /ci/peer-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1 diff --git a/client/Dockerfile b/client/Dockerfile index 35c1d04c2..16b2916c7 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,5 +1,6 @@ FROM alpine:3.21.3 -RUN apk add --no-cache ca-certificates iptables ip6tables +# iproute2: busybox doesn't display ip rules properly +RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] -COPY netbird /usr/local/bin/netbird \ No newline at end of file +COPY netbird /usr/local/bin/netbird diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 53ee6c886..bd87879a5 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -198,12 +198,12 @@ func TestTracePacket(t *testing.T) { m.forwarder.Store(&forwarder.Forwarder{}) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) - dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -222,12 +222,12 @@ func TestTracePacket(t *testing.T) { m.nativeRouter.Store(false) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) - dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) + dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) require.NoError(t, err) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -245,7 +245,7 @@ func TestTracePacket(t *testing.T) { m.nativeRouter.Store(true) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -263,7 +263,7 @@ func TestTracePacket(t *testing.T) { m.routingEnabled.Store(false) }, packetBuilder: func() *PacketBuilder { - return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) + return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN) }, expectedStages: []PacketStage{ StageReceived, @@ -425,8 +425,8 @@ func TestTracePacket(t *testing.T) { require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")), "100.10.0.100 should be recognized as a local IP") - require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")), - "172.17.0.2 should not be recognized as a local IP") + require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")), + "192.168.17.2 should not be recognized as a local IP") pb := tc.packetBuilder() diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 84b84483e..a83d7f1de 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -24,7 +24,6 @@ func init() { testCases = append(testCases, []testCase{ { name: "To more specific route without custom dialer via vpn", - destination: "10.10.0.2:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), diff --git a/client/internal/routemanager/systemops/systemops_linux_test.go b/client/internal/routemanager/systemops/systemops_linux_test.go index 8f12740d0..f0d7472dc 100644 --- a/client/internal/routemanager/systemops/systemops_linux_test.go +++ b/client/internal/routemanager/systemops/systemops_linux_test.go @@ -27,14 +27,12 @@ func init() { testCases = append(testCases, []testCase{ { name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", expectedInterface: expectedInternalInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), }, { name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", expectedInterface: expectedLoopbackInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), @@ -134,6 +132,16 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + + route := &netlink.Route{ + Dst: dstIPNet, + Gw: gw, + LinkIndex: linkIndex, + } + // Handle existing routes with metric 0 var originalNexthop net.IP var originalLinkIndex int @@ -145,32 +153,24 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { } if originalNexthop != nil { + // remove original route err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - switch { - case err != nil && !errors.Is(err, syscall.ESRCH): - t.Logf("Failed to delete route: %v", err) - case err == nil: - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - default: - t.Logf("Failed to delete route: %v", err) - } + assert.NoError(t, err) + + // add new route + assert.NoError(t, netlink.RouteAdd(route)) + + t.Cleanup(func() { + // restore original route + assert.NoError(t, netlink.RouteDel(route)) + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + assert.NoError(t, err) + }) + + return } } - link, err := netlink.LinkByName(intf) - require.NoError(t, err) - linkIndex := link.Attrs().Index - - route := &netlink.Route{ - Dst: dstIPNet, - Gw: gw, - LinkIndex: linkIndex, - } err = netlink.RouteDel(route) if err != nil && !errors.Is(err, syscall.ESRCH) { t.Logf("Failed to delete route: %v", err) @@ -180,7 +180,6 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } - require.NoError(t, err) } func fetchOriginalGateway(family int) (net.IP, int, error) { @@ -190,7 +189,11 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil && route.Priority == 0 { + ones := -1 + if route.Dst != nil { + ones, _ = route.Dst.Mask.Size() + } + if route.Dst == nil || ones == 0 && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index d88c1ab6b..ad37f611f 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -31,7 +31,6 @@ type PacketExpectation struct { type testCase struct { name string - destination string expectedInterface string dialer dialer expectedPacket PacketExpectation @@ -40,14 +39,12 @@ type testCase struct { var testCases = []testCase{ { name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), }, { name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), @@ -55,14 +52,12 @@ var testCases = []testCase{ { name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", expectedInterface: expectedInternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), }, { name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", expectedInterface: expectedInternalInt, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), @@ -70,14 +65,12 @@ var testCases = []testCase{ { name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), }, { name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", expectedInterface: expectedVPNint, dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), @@ -94,10 +87,11 @@ func TestRouting(t *testing.T) { t.Run(tc.name, func(t *testing.T) { setupTestEnv(t) - filter := createBPFFilter(tc.destination) + dst := fmt.Sprintf("%s:%d", tc.expectedPacket.DstIP, tc.expectedPacket.DstPort) + filter := createBPFFilter(dst) handle := startPacketCapture(t, tc.expectedInterface, filter) - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + sendTestPacket(t, dst, tc.expectedPacket.SrcPort, tc.dialer) packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) packet, err := packetSource.NextPacket() From 84bfecdd3743afc73aa2f614500e52e0d6b8bbad Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:10:41 +0300 Subject: [PATCH 21/45] [client] add byte counters & ruleID for routed traffic on userspace (#3653) * [client] add byte counters for routed traffic on userspace * [client] add allowed ruleID for routed traffic on userspace --- .../firewall/uspfilter/forwarder/forwarder.go | 41 +++++- client/firewall/uspfilter/forwarder/icmp.go | 69 ++++++---- client/firewall/uspfilter/forwarder/tcp.go | 118 +++++++++++------- client/firewall/uspfilter/forwarder/udp.go | 108 ++++++++++------ client/firewall/uspfilter/uspfilter.go | 6 +- 5 files changed, 235 insertions(+), 107 deletions(-) diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 0dff3acc7..2ae983f6e 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" + "net/netip" "runtime" + "sync" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/buffer" @@ -17,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -29,8 +32,10 @@ const ( ) type Forwarder struct { - logger *nblog.Logger - flowLogger nftypes.FlowLogger + logger *nblog.Logger + flowLogger nftypes.FlowLogger + // ruleIdMap is used to store the rule ID for a given connection + ruleIdMap sync.Map stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder @@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { } return addr.AsSlice() } + +func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { + key := buildKey(srcIP, dstIP, srcPort, dstPort) + f.ruleIdMap.LoadOrStore(key, ruleID) +} + +func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { + + if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return value.([]byte), true + } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { + return value.([]byte), true + } + + return nil, false +} + +func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { + if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { + return + } + f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort)) +} + +func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey { + return conntrack.ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index a21ec2c87..08d77ed05 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf } flowID := uuid.New() - f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) + f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) defer cancel() @@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf // TODO: support non-root conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") if err != nil { - f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err) // This will make netstack reply on behalf of the original destination, that's ok for now return false } defer func() { if err := conn.Close(); err != nil { - f.logger.Debug("Failed to close ICMP socket: %v", err) + f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err) } }() @@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf payload := fullPacket.AsSlice() if _, err = conn.WriteTo(payload, dst); err != nil { - f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) + f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err) return true } - f.logger.Trace("Forwarded ICMP packet %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) // For Echo Requests, send and handle response if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { - f.handleEchoResponse(icmpHdr, conn, id) - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) + rxBytes := pkt.Size() + txBytes := f.handleEchoResponse(icmpHdr, conn, id) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing return true } -func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) { +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - f.logger.Error("Failed to set read deadline for ICMP response: %v", err) - return + f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err) + return 0 } response := make([]byte, f.endpoint.mtu) n, _, err := conn.ReadFrom(response) if err != nil { if !isTimeout(err) { - f.logger.Error("Failed to read ICMP response: %v", err) + f.logger.Error("forwarder: Failed to read ICMP response: %v", err) } - return + return 0 } ipHdr := make([]byte, header.IPv4MinimumSize) @@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon fullPacket = append(fullPacket, response[:n]...) if err := f.InjectIncomingPacket(fullPacket); err != nil { - f.logger.Error("Failed to inject ICMP response: %v", err) + f.logger.Error("forwarder: Failed to inject ICMP response: %v", err) - return + return 0 } - f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", + f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) + + return len(fullPacket) } // sendICMPEvent stores flow events for ICMP packets -func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) { - f.flowLogger.StoreEvent(nftypes.EventFields{ +func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) { + var rxPackets, txPackets uint64 + if rxBytes > 0 { + rxPackets = 1 + } + if txBytes > 0 { + txPackets = 1 + } + + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.ICMP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, ICMPType: icmpType, ICMPCode: icmpCode, - // TODO: get packets/bytes - }) + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, + } + + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId + } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) + } + + f.flowLogger.StoreEvent(fields) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 71cd457ef..04b3ae233 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -6,8 +6,10 @@ import ( "io" "net" "net/netip" + "sync" "github.com/google/uuid" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { flowID := uuid.New() - f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { - defer func() { - if err := inConn.Close(); err != nil { - f.logger.Debug("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil { - f.logger.Debug("forwarder: outConn close error: %v", err) - } - ep.Close() - f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep) - }() - - // Create context for managing the proxy goroutines ctx, cancel := context.WithCancel(f.ctx) defer cancel() - errChan := make(chan error, 2) - go func() { - _, err := io.Copy(outConn, inConn) - errChan <- err - }() - - go func() { - _, err := io.Copy(inConn, outConn) - errChan <- err - }() - - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyTCP: copy error: %v", err) + <-ctx.Done() + // Close connections and endpoint. + if err := inConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil && !isClosedError(err) { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + + ep.Close() + }() + + var wg sync.WaitGroup + wg.Add(2) + + var ( + bytesFromInToOut int64 // bytes from client to server (tx for client) + bytesFromOutToIn int64 // bytes from server to client (rx for client) + errInToOut error + errOutToIn error + ) + + go func() { + bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) + cancel() + wg.Done() + }() + + go func() { + + bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) + cancel() + wg.Done() + }() + + wg.Wait() + + if errInToOut != nil { + if !isClosedError(errInToOut) { + f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) } - f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id)) - return } + if errOutToIn != nil { + if !isClosedError(errOutToIn) { + f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) + } + } + + var rxPackets, txPackets uint64 + if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { + // fields are flipped since this is the in conn + rxPackets = tcpStats.SegmentsSent.Value() + txPackets = tcpStats.SegmentsReceived.Value() + } + + f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } -func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.TCP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.SegmentsSent.Value() - fields.TxPackets = tcpStats.SegmentsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 7ce85e2b6..cb88aa59a 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { flowID := uuid.New() - f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0) var success bool defer func() { if !success { - f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil) + f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0) } }() @@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - return } f.udpForwarder.conns[id] = pConn @@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { } func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { - defer func() { + + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + go func() { + <-ctx.Done() + pConn.cancel() - if err := pConn.conn.Close(); err != nil { + if err := pConn.conn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) } - if err := pConn.outConn.Close(); err != nil { + if err := pConn.outConn.Close(); err != nil && !isClosedError(err) { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } ep.Close() - - f.udpForwarder.Lock() - delete(f.udpForwarder.conns, id) - f.udpForwarder.Unlock() - - f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep) }() - errChan := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + var txBytes, rxBytes int64 + var outboundErr, inboundErr error + + // outbound->inbound: copy from pConn.conn to pConn.outConn go func() { - errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") + defer wg.Done() + txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") }() + // inbound->outbound: copy from pConn.outConn to pConn.conn go func() { - errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") + defer wg.Done() + rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") }() - select { - case <-ctx.Done(): - f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id)) - return - case err := <-errChan: - if err != nil && !isClosedError(err) { - f.logger.Error("proxyUDP: copy error: %v", err) - } - f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id)) - return + wg.Wait() + + if outboundErr != nil && !isClosedError(outboundErr) { + f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) } + if inboundErr != nil && !isClosedError(inboundErr) { + f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) + } + + var rxPackets, txPackets uint64 + if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { + // fields are flipped since this is the in conn + rxPackets = udpStats.PacketsSent.Value() + txPackets = udpStats.PacketsReceived.Value() + } + + f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, id) + f.udpForwarder.Unlock() + + f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets) } // sendUDPEvent stores flow events for UDP connections -func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { +func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { + srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) + dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, Protocol: nftypes.UDP, // TODO: handle ipv6 - SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), - DestIP: netip.AddrFrom4(id.LocalAddress.As4()), + SourceIP: srcIp, + DestIP: dstIp, SourcePort: id.RemotePort, DestPort: id.LocalPort, + RxBytes: rxBytes, + TxBytes: txBytes, + RxPackets: rxPackets, + TxPackets: txPackets, } - if ep != nil { - if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { - // fields are flipped since this is the in conn - // TODO: get bytes - fields.RxPackets = tcpStats.PacketsSent.Value() - fields.TxPackets = tcpStats.PacketsReceived.Value() + if typ == nftypes.TypeStart { + if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok { + fields.RuleID = ruleId } + } else { + f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort) } f.flowLogger.StoreEvent(fields) @@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration { return time.Since(lastSeen) } -func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { +// copy reads from src and writes to dst. +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) { bufp := bufPool.Get().(*[]byte) defer bufPool.Put(bufp) buffer := *bufp + var totalBytes int64 = 0 for { if ctx.Err() != nil { - return ctx.Err() + return totalBytes, ctx.Err() } if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { - return fmt.Errorf("set read deadline: %w", err) + return totalBytes, fmt.Errorf("set read deadline: %w", err) } n, err := src.Read(buffer) @@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu if isTimeout(err) { continue } - return fmt.Errorf("read from %s: %w", direction, err) + return totalBytes, fmt.Errorf("read from %s: %w", direction, err) } - _, err = dst.Write(buffer[:n]) + nWritten, err := dst.Write(buffer[:n]) if err != nil { - return fmt.Errorf("write to %s: %w", direction, err) + return totalBytes, fmt.Errorf("write to %s: %w", direction, err) } + totalBytes += int64(nWritten) c.updateLastSeen() } } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index ccf0be225..11730dbb3 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -824,7 +824,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe proto, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { + ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + if !pass { m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", ruleID, pnum, srcIP, srcPort, dstIP, dstPort) @@ -850,8 +851,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if fwd == nil { m.logger.Trace("failed to forward routed packet (forwarder not initialized)") } else { + fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID) + if err := fwd.InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject routed packet: %v", err) + fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort) } } From 47c3afe56199f7d36df2dd0e7eb4e3be2a7c5ecb Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:05:27 +0100 Subject: [PATCH 22/45] [management] add missing network admin mapping (#3751) --- management/server/types/user.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/management/server/types/user.go b/management/server/types/user.go index e17a29bee..a2596b3cb 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -39,6 +39,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleBillingAdmin case "auditor": return UserRoleAuditor + case "network_admin": + return UserRoleNetworkAdmin default: return UserRoleUnknown } From 3fa915e271466b3749336bc20c0f06f3aa802e10 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:40:36 +0200 Subject: [PATCH 23/45] [misc] Exclude client benchmarks from CI (#3752) --- .github/workflows/golang-test-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 2f1df9b1a..faadcb3b3 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -415,7 +415,7 @@ jobs: CI=true \ go test -tags devcert -run=^$ -bench=. \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./... + -timeout 20m ./management/... api_benchmark: name: "Management / Benchmark (API)" From d8dc107bee6102b6130c8b8cdd64764d42dd12f3 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 28 Apr 2025 15:10:40 +0300 Subject: [PATCH 24/45] [management] Skip IdP cache warm-up on Redis if data exists (#3733) * Add Redis cache check to skip warm-up on startup if cache is already populated * Refactor Redis test container setup for reusability --- management/server/account.go | 49 ++++++++++++++- management/server/account_test.go | 72 +++++++++++++++++++---- management/server/cache/idp_test.go | 17 +----- management/server/store/sql_store.go | 13 ++++ management/server/store/sql_store_test.go | 25 ++++++++ management/server/store/store.go | 1 + management/server/testutil/store.go | 26 ++++++++ management/server/testutil/store_ios.go | 6 ++ 8 files changed, 182 insertions(+), 27 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index cc5ca309a..ab1ffe8b3 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -17,6 +17,7 @@ import ( "time" cacheStore "github.com/eko/gocache/lib/v4/store" + "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/vmihailenco/msgpack/v5" @@ -237,7 +238,7 @@ func BuildManager( if !isNil(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx) + err := am.warmupIDPCache(ctx, cacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? @@ -494,7 +495,25 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain return nil, status.Errorf(status.Internal, "error while creating new account") } -func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { +func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cacheStore.StoreInterface) error { + cold, err := am.isCacheCold(ctx, store) + if err != nil { + return err + } + + if !cold { + log.WithContext(ctx).Debug("cache already populated, skipping warm up") + return nil + } + + if delayStr, ok := os.LookupEnv("NB_IDP_CACHE_WARMUP_DELAY"); ok { + delay, err := time.ParseDuration(delayStr) + if err != nil { + return fmt.Errorf("invalid IDP warmup delay: %w", err) + } + time.Sleep(delay) + } + userData, err := am.idpManager.GetAllAccounts(ctx) if err != nil { return err @@ -534,6 +553,32 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { return nil } +// isCacheCold checks if the cache needs warming up. +func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheStore.StoreInterface) (bool, error) { + if store.GetType() != redis.RedisType { + return true, nil + } + + accountID, err := am.Store.GetAnyAccountID(ctx) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return true, nil + } + return false, err + } + + _, err = store.Get(ctx, accountID) + if err == nil { + return false, nil + } + + if notFoundErr := new(cacheStore.NotFound); errors.As(err, ¬FoundErr) { + return true, nil + } + + return false, fmt.Errorf("failed to check cache: %w", err) +} + // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f34cf845..fe082d9a0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -14,30 +14,30 @@ import ( "time" "github.com/golang/mock/gomock" - - nbAccount "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/util" - - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - + "github.com/netbirdio/netbird/management/server/idp" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -3201,3 +3201,53 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { assert.NoError(t, err) assert.True(t, account.IsDomainPrimaryAccount) } + +func TestDefaultAccountManager_IsCacheCold(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + t.Run("memory cache", func(t *testing.T) { + t.Run("should always return true", func(t *testing.T) { + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + require.NoError(t, err) + + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + }) + + t.Run("redis cache", func(t *testing.T) { + cleanup, redisURL, err := testutil.CreateRedisTestContainer() + require.NoError(t, err) + t.Cleanup(cleanup) + t.Setenv(cache.RedisStoreEnvVar, redisURL) + + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + require.NoError(t, err) + + t.Run("should return true when no account exists", func(t *testing.T) { + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + t.Run("should return true when account is not found in cache", func(t *testing.T) { + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + + t.Run("should return false when account is found in cache", func(t *testing.T) { + err = cacheStore.Set(context.Background(), account.Id, &idp.UserData{ID: "v", Name: "vv"}) + require.NoError(t, err) + + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.False(t, cold) + }) + }) +} diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index beefcd9bd..3fcfbb11a 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -8,12 +8,11 @@ import ( "github.com/eko/gocache/lib/v4/store" "github.com/redis/go-redis/v9" - "github.com/testcontainers/testcontainers-go" - testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/vmihailenco/msgpack/v5" "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/testutil" ) func TestNewIDPCacheManagers(t *testing.T) { @@ -27,21 +26,11 @@ func TestNewIDPCacheManagers(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { if tc.redis { - ctx := context.Background() - redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7")) + cleanup, redisURL, err := testutil.CreateRedisTestContainer() if err != nil { t.Fatalf("couldn't start redis container: %s", err) } - defer func() { - if err := redisContainer.Terminate(ctx); err != nil { - t.Logf("failed to terminate container: %s", err) - } - }() - redisURL, err := redisContainer.ConnectionString(ctx) - if err != nil { - t.Fatalf("couldn't get connection string: %s", err) - } - + t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index b73c372ae..7d3b288e0 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -800,6 +800,19 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( return s.GetAccount(ctx, peer.AccountID) } +func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { + var account types.Account + result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account) + if result.Error != nil { + return "", status.NewGetAccountFromStoreError(result.Error) + } + if result.RowsAffected == 0 { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return account.Id, nil +} + func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index c16a50108..8bd8ce098 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3263,3 +3263,28 @@ func TestSqlStore_GetAccountMeta(t *testing.T) { require.Equal(t, "private", accountMeta.DomainCategory) require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) } + +func TestSqlStore_GetAnyAccountID(t *testing.T) { + t.Run("should return account ID when accounts exist", func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID, err := store.GetAnyAccountID(context.Background()) + require.NoError(t, err) + assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID) + }) + + t.Run("should return error when no accounts exist", func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID, err := store.GetAnyAccountID(context.Background()) + require.Error(t, err) + sErr, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, sErr.Type(), status.NotFound) + assert.Empty(t, accountID) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 4a26bf5c3..ca332a493 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -55,6 +55,7 @@ type Store interface { GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) + GetAnyAccountID(ctx context.Context) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 8672efa7f..ca022bfef 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -12,6 +12,7 @@ import ( "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" + testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/testcontainers/testcontainers-go/wait" ) @@ -84,3 +85,28 @@ func CreatePostgresTestContainer() (func(), error) { return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) } + +// CreateRedisTestContainer creates a new Redis container for testing. +func CreateRedisTestContainer() (func(), string, error) { + ctx := context.Background() + + redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7")) + if err != nil { + return nil, "", err + } + + cleanup := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = redisContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop redis container %s: %s", redisContainer.GetContainerID(), err) + } + } + + redisURL, err := redisContainer.ConnectionString(ctx) + if err != nil { + return nil, "", err + } + + return cleanup, redisURL, nil +} diff --git a/management/server/testutil/store_ios.go b/management/server/testutil/store_ios.go index edde62f1e..a614258d2 100644 --- a/management/server/testutil/store_ios.go +++ b/management/server/testutil/store_ios.go @@ -14,3 +14,9 @@ func CreateMysqlTestContainer() (func(), error) { // Empty function for MySQL }, nil } + +func CreateRedisTestContainer() (func(), string, error) { + return func() { + // Empty function for Redis + }, "", nil +} From 2f44fe2e23579e362338f049b81252354335fa5a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 29 Apr 2025 00:43:50 +0200 Subject: [PATCH 25/45] [client] Feature/upload bundle (#3734) Add an upload bundle option with the flag --upload-bundle; by default, the upload will use a NetBird address, which can be replaced using the flag --upload-bundle-url. The upload server is available under the /upload-server path. The release change will push a docker image to netbirdio/upload image repository. The server supports using s3 with pre-signed URL for direct upload and local file for storing bundles. --- .github/workflows/golang-test-linux.yml | 3 +- .goreleaser.yaml | 70 ++++ client/cmd/debug.go | 36 +- client/cmd/root.go | 8 + client/proto/daemon.pb.go | 509 +++++++++++++----------- client/proto/daemon.proto | 3 + client/server/debug.go | 105 ++++- client/server/debug_test.go | 49 +++ go.mod | 31 +- go.sum | 62 +-- upload-server/Dockerfile | 3 + upload-server/main.go | 22 + upload-server/server/local.go | 124 ++++++ upload-server/server/local_test.go | 65 +++ upload-server/server/s3.go | 69 ++++ upload-server/server/s3_test.go | 103 +++++ upload-server/server/server.go | 109 +++++ upload-server/types/upload.go | 16 + 18 files changed, 1100 insertions(+), 287 deletions(-) create mode 100644 client/server/debug_test.go create mode 100644 upload-server/Dockerfile create mode 100644 upload-server/main.go create mode 100644 upload-server/server/local.go create mode 100644 upload-server/server/local_test.go create mode 100644 upload-server/server/s3.go create mode 100644 upload-server/server/s3_test.go create mode 100644 upload-server/server/server.go create mode 100644 upload-server/types/upload.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index faadcb3b3..d585ba209 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -194,6 +194,7 @@ jobs: -v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \ -e CGO_ENABLED=1 \ -e CI=true \ + -e DOCKER_CI=true \ -e GOARCH=${GOARCH_TARGET} \ -e GOCACHE=${CONTAINER_GOCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ @@ -201,7 +202,7 @@ jobs: sh -c ' \ apk update; apk add --no-cache \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ - go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui) + go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server) ' test_relay: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index d6479763e..112659d1c 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -96,6 +96,20 @@ builds: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird-upload + dir: upload-server + env: [CGO_ENABLED=0] + binary: netbird-upload + goos: + - linux + goarch: + - amd64 + - arm64 + - arm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + universal_binaries: - id: netbird @@ -409,6 +423,52 @@ dockers: - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/upload:{{ .Version }}-amd64 + ids: + - netbird-upload + goarch: amd64 + use: buildx + dockerfile: upload-server/Dockerfile + build_flag_templates: + - "--platform=linux/amd64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/upload:{{ .Version }}-arm64v8 + ids: + - netbird-upload + goarch: arm64 + use: buildx + dockerfile: upload-server/Dockerfile + build_flag_templates: + - "--platform=linux/arm64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/upload:{{ .Version }}-arm + ids: + - netbird-upload + goarch: arm + goarm: 6 + use: buildx + dockerfile: upload-server/Dockerfile + build_flag_templates: + - "--platform=linux/arm" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" docker_manifests: - name_template: netbirdio/netbird:{{ .Version }} image_templates: @@ -475,7 +535,17 @@ docker_manifests: - netbirdio/management:{{ .Version }}-debug-arm64v8 - netbirdio/management:{{ .Version }}-debug-arm - netbirdio/management:{{ .Version }}-debug-amd64 + - name_template: netbirdio/upload:{{ .Version }} + image_templates: + - netbirdio/upload:{{ .Version }}-arm64v8 + - netbirdio/upload:{{ .Version }}-arm + - netbirdio/upload:{{ .Version }}-amd64 + - name_template: netbirdio/upload:latest + image_templates: + - netbirdio/upload:{{ .Version }}-arm64v8 + - netbirdio/upload:{{ .Version }}-arm + - netbirdio/upload:{{ .Version }}-amd64 brews: - ids: - default diff --git a/client/cmd/debug.go b/client/cmd/debug.go index d2e5bdd7e..b4adee826 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -87,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error { }() client := proto.NewDaemonServiceClient(conn) - resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ + request := &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, Status: getStatusOutput(cmd, anonymizeFlag), SystemInfo: debugSystemInfoFlag, - }) + } + if debugUploadBundle { + request.UploadURL = debugUploadBundleURL + } + resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + cmd.Printf("Local file:\n%s\n", resp.GetPath()) - cmd.Println(resp.GetPath()) + if resp.GetUploadFailureReason() != "" { + return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) + } + + if debugUploadBundle { + cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) + } return nil } @@ -211,12 +222,15 @@ func runForDuration(cmd *cobra.Command, args []string) error { headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) - - resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ + request := &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, Status: statusOutput, SystemInfo: debugSystemInfoFlag, - }) + } + if debugUploadBundle { + request.UploadURL = debugUploadBundleURL + } + resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } @@ -242,7 +256,15 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Log level restored to", initialLogLevel.GetLevel()) } - cmd.Println(resp.GetPath()) + cmd.Printf("Local file:\n%s\n", resp.GetPath()) + + if resp.GetUploadFailureReason() != "" { + return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) + } + + if debugUploadBundle { + cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) + } return nil } diff --git a/client/cmd/root.go b/client/cmd/root.go index baf444b99..b4f067078 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/upload-server/types" ) const ( @@ -39,6 +40,9 @@ const ( dnsRouteIntervalFlag = "dns-router-interval" systemInfoFlag = "system-info" blockLANAccessFlag = "block-lan-access" + uploadBundle = "upload-bundle" + uploadBundleURL = "upload-bundle-url" + defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath ) var ( @@ -75,6 +79,8 @@ var ( debugSystemInfoFlag bool dnsRouteInterval time.Duration blockLANAccess bool + debugUploadBundle bool + debugUploadBundleURL string rootCmd = &cobra.Command{ Use: "netbird", @@ -181,6 +187,8 @@ func init() { upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") + debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) + debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle") } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index d04d7a9c0..879fb8032 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v3.21.9 // source: daemon.proto package proto @@ -2277,6 +2277,7 @@ type DebugBundleRequest struct { Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"` Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"` + UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"` } func (x *DebugBundleRequest) Reset() { @@ -2332,12 +2333,21 @@ func (x *DebugBundleRequest) GetSystemInfo() bool { return false } +func (x *DebugBundleRequest) GetUploadURL() string { + if x != nil { + return x.UploadURL + } + return "" +} + type DebugBundleResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` + UploadedKey string `protobuf:"bytes,2,opt,name=uploadedKey,proto3" json:"uploadedKey,omitempty"` + UploadFailureReason string `protobuf:"bytes,3,opt,name=uploadFailureReason,proto3" json:"uploadFailureReason,omitempty"` } func (x *DebugBundleResponse) Reset() { @@ -2379,6 +2389,20 @@ func (x *DebugBundleResponse) GetPath() string { return "" } +func (x *DebugBundleResponse) GetUploadedKey() string { + if x != nil { + return x.UploadedKey + } + return "" +} + +func (x *DebugBundleResponse) GetUploadFailureReason() string { + if x != nil { + return x.UploadFailureReason + } + return "" +} + type GetLogLevelRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3924,244 +3948,251 @@ var file_daemon_proto_rawDesc = []byte{ 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2c, 0x0a, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x05, 0x72, 0x75, 0x6c, - 0x65, 0x73, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, - 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, - 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, - 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, - 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, - 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, - 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, - 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, - 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, - 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, - 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, - 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, - 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, - 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, - 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, - 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, - 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x76, 0x0a, 0x08, 0x54, - 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x66, - 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x10, 0x0a, - 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x72, 0x73, 0x74, 0x12, - 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x73, - 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, - 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, - 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x70, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, - 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, - 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x48, 0x00, 0x52, 0x08, 0x74, 0x63, 0x70, - 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, - 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x01, 0x52, 0x08, 0x69, - 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, - 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x02, 0x52, - 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x88, 0x01, 0x01, 0x42, 0x0c, 0x0a, 0x0a, - 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, - 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, - 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f, 0x01, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, - 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x32, 0x0a, - 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, - 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x11, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x88, 0x01, - 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, - 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x22, 0x6e, 0x0a, 0x13, 0x54, 0x72, 0x61, 0x63, - 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, - 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, - 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, - 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, 0x75, 0x62, 0x73, - 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x93, 0x04, 0x0a, - 0x0b, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, 0x02, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x38, 0x0a, 0x08, - 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x52, 0x08, 0x73, 0x65, - 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x38, 0x0a, 0x08, 0x63, 0x61, 0x74, 0x65, 0x67, 0x6f, - 0x72, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x61, - 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x52, 0x08, 0x63, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, - 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x75, 0x73, - 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x75, 0x73, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x09, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x3d, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, - 0x38, 0x01, 0x22, 0x3a, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x08, - 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x57, 0x41, 0x52, 0x4e, - 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x02, - 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, 0x03, 0x22, 0x52, - 0x0a, 0x08, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x12, 0x0b, 0x0a, 0x07, 0x4e, 0x45, - 0x54, 0x57, 0x4f, 0x52, 0x4b, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x4e, 0x53, 0x10, 0x01, - 0x12, 0x12, 0x0a, 0x0e, 0x41, 0x55, 0x54, 0x48, 0x45, 0x4e, 0x54, 0x49, 0x43, 0x41, 0x54, 0x49, - 0x4f, 0x4e, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x43, 0x4f, 0x4e, 0x4e, 0x45, 0x43, 0x54, 0x49, - 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, 0x54, 0x45, 0x4d, - 0x10, 0x04, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, - 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, 0x0a, 0x06, 0x65, - 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, - 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, - 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, - 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb3, 0x0b, 0x0a, - 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, - 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, - 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0f, 0x46, - 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, - 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, - 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, - 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, - 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, - 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, + 0x65, 0x73, 0x22, 0x88, 0x01, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, + 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, + 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, + 0x1c, 0x0a, 0x09, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x55, 0x52, 0x4c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x55, 0x52, 0x4c, 0x22, 0x7d, 0x0a, + 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x75, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x75, + 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x75, 0x70, + 0x6c, 0x6f, 0x61, 0x64, 0x46, 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x52, 0x65, 0x61, 0x73, 0x6f, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x46, + 0x61, 0x69, 0x6c, 0x75, 0x72, 0x65, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x22, 0x14, 0x0a, 0x12, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, + 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, + 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, - 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, - 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, - 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, - 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, - 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, - 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, 0x53, 0x75, 0x62, 0x73, - 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, - 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, - 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, + 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, + 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, + 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, + 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x76, + 0x0a, 0x08, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x79, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79, 0x6e, 0x12, 0x10, 0x0a, 0x03, + 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x12, 0x10, + 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x66, 0x69, 0x6e, + 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x72, + 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x03, 0x70, 0x73, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x03, 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12, 0x54, 0x72, 0x61, 0x63, 0x65, + 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, + 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, + 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, + 0x70, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1f, 0x0a, + 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x29, + 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x6f, + 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09, 0x74, 0x63, 0x70, 0x5f, 0x66, + 0x6c, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x48, 0x00, 0x52, 0x08, + 0x74, 0x63, 0x70, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, + 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x01, + 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88, 0x01, 0x01, 0x12, 0x20, 0x0a, + 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0d, + 0x48, 0x02, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x88, 0x01, 0x01, 0x42, + 0x0c, 0x0a, 0x0a, 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x42, 0x0c, 0x0a, + 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, + 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f, 0x01, 0x0a, 0x0a, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x12, 0x32, 0x0a, 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, + 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x11, + 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x88, 0x01, 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, + 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x22, 0x6e, 0x0a, 0x13, 0x54, + 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, + 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, 0x73, 0x12, 0x2b, + 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, + 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x66, 0x69, 0x6e, 0x61, 0x6c, + 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x12, 0x0a, 0x10, 0x53, + 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0x93, 0x04, 0x0a, 0x0b, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, + 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, + 0x38, 0x0a, 0x08, 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, + 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x52, + 0x08, 0x73, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, 0x79, 0x12, 0x38, 0x0a, 0x08, 0x63, 0x61, 0x74, + 0x65, 0x67, 0x6f, 0x72, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x2e, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x52, 0x08, 0x63, 0x61, 0x74, 0x65, 0x67, + 0x6f, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x20, 0x0a, + 0x0b, 0x75, 0x73, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0b, 0x75, 0x73, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, + 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, + 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x3d, 0x0a, 0x08, 0x6d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, + 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, + 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x3a, 0x0a, 0x08, 0x53, 0x65, 0x76, 0x65, 0x72, 0x69, 0x74, + 0x79, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x57, + 0x41, 0x52, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x02, 0x12, 0x0c, 0x0a, 0x08, 0x43, 0x52, 0x49, 0x54, 0x49, 0x43, 0x41, 0x4c, 0x10, + 0x03, 0x22, 0x52, 0x0a, 0x08, 0x43, 0x61, 0x74, 0x65, 0x67, 0x6f, 0x72, 0x79, 0x12, 0x0b, 0x0a, + 0x07, 0x4e, 0x45, 0x54, 0x57, 0x4f, 0x52, 0x4b, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x4e, + 0x53, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x41, 0x55, 0x54, 0x48, 0x45, 0x4e, 0x54, 0x49, 0x43, + 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x43, 0x4f, 0x4e, 0x4e, 0x45, + 0x43, 0x54, 0x49, 0x56, 0x49, 0x54, 0x59, 0x10, 0x03, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, + 0x54, 0x45, 0x4d, 0x10, 0x04, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, + 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, 0x65, 0x74, + 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, + 0x0a, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, + 0x65, 0x6e, 0x74, 0x52, 0x06, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2a, 0x62, 0x0a, 0x08, 0x4c, + 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, + 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, + 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, + 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, + 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, + 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, + 0xb3, 0x0b, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, + 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4a, + 0x0a, 0x0f, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, + 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, + 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, + 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, + 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, + 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, + 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, + 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, + 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x44, 0x0a, 0x0f, 0x53, + 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x18, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 49e577853..6c63a8f9b 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -336,10 +336,13 @@ message DebugBundleRequest { bool anonymize = 1; string status = 2; bool systemInfo = 3; + string uploadURL = 4; } message DebugBundleResponse { string path = 1; + string uploadedKey = 2; + string uploadFailureReason = 3; } enum LogLevel { diff --git a/client/server/debug.go b/client/server/debug.go index 9ccfb13fb..b42b1467a 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -4,16 +4,24 @@ package server import ( "context" + "crypto/sha256" + "encoding/json" "errors" "fmt" + "io" + "net/http" + "os" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/proto" mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/upload-server/types" ) +const maxBundleUploadSize = 50 * 1024 * 1024 + // DebugBundle creates a debug bundle and returns the location. func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -42,7 +50,102 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( return nil, fmt.Errorf("generate debug bundle: %w", err) } - return &proto.DebugBundleResponse{Path: path}, nil + if req.GetUploadURL() == "" { + + return &proto.DebugBundleResponse{Path: path}, nil + } + key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) + if err != nil { + return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil + } + + return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil +} + +func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) { + response, err := getUploadURL(ctx, url, managementURL) + if err != nil { + return "", err + } + + err = upload(ctx, filePath, response) + if err != nil { + return "", err + } + return response.Key, nil +} + +func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error { + fileData, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("open file: %w", err) + } + + defer fileData.Close() + + stat, err := fileData.Stat() + if err != nil { + return fmt.Errorf("stat file: %w", err) + } + + if stat.Size() > maxBundleUploadSize { + return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize) + } + + req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData) + if err != nil { + return fmt.Errorf("create PUT request: %w", err) + } + + req.ContentLength = stat.Size() + req.Header.Set("Content-Type", "application/octet-stream") + + putResp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("upload failed: %v", err) + } + defer putResp.Body.Close() + + if putResp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(putResp.Body) + return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body)) + } + return nil +} + +func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) { + id := getURLHash(managementURL) + getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil) + if err != nil { + return nil, fmt.Errorf("create GET request: %w", err) + } + + getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + resp, err := http.DefaultClient.Do(getReq) + if err != nil { + return nil, fmt.Errorf("get presigned URL: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body)) + } + + urlBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + var response types.GetURLResponse + if err := json.Unmarshal(urlBytes, &response); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + return &response, nil +} + +func getURLHash(url string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(url))) } // GetLogLevel gets the current logging level for the server. diff --git a/client/server/debug_test.go b/client/server/debug_test.go new file mode 100644 index 000000000..53d9ac8ed --- /dev/null +++ b/client/server/debug_test.go @@ -0,0 +1,49 @@ +package server + +import ( + "context" + "errors" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/upload-server/server" + "github.com/netbirdio/netbird/upload-server/types" +) + +func TestUpload(t *testing.T) { + if os.Getenv("DOCKER_CI") == "true" { + t.Skip("Skipping upload test on docker ci") + } + testDir := t.TempDir() + testURL := "http://localhost:8080" + t.Setenv("SERVER_URL", testURL) + t.Setenv("STORE_DIR", testDir) + srv := server.NewServer() + go func() { + if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("Failed to start server: %v", err) + } + }() + t.Cleanup(func() { + if err := srv.Stop(); err != nil { + t.Errorf("Failed to stop server: %v", err) + } + }) + + file := filepath.Join(t.TempDir(), "tmpfile") + fileContent := []byte("test file content") + err := os.WriteFile(file, fileContent, 0640) + require.NoError(t, err) + key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file) + require.NoError(t, err) + id := getURLHash(testURL) + require.Contains(t, key, id+"/") + expectedFilePath := filepath.Join(testDir, key) + createdFileContent, err := os.ReadFile(expectedFilePath) + require.NoError(t, err) + require.Equal(t, fileContent, createdFileContent) +} diff --git a/go.mod b/go.mod index 095840f13..2b3ef9cd6 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,9 @@ require ( fyne.io/fyne/v2 v2.5.3 fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible + github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 @@ -123,20 +126,22 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect - github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect - github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect - github.com/aws/smithy-go v1.20.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 8c1c021f8..a90db83de 100644 --- a/go.sum +++ b/go.sum @@ -74,34 +74,44 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= -github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= -github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= -github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= -github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= -github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= -github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10/go.mod h1:qqvMj6gHLR/EXWZw4ZbqlPbQUyenf4h82UQUlKc+l14= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 h1:ZNTqv4nIdE/DiBfUUfXcLZ/Spcuz+RjeziUtNJackkM= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34/go.mod h1:zf7Vcd1ViW7cPqYWEHLHJkS50X0JS2IKz9Cgaj6ugrs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0 h1:lguz0bmOoGzozP9XfRJR1QIayEYo+2vP/No3OfLF0pU= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.0/go.mod h1:iu6FSzgt+M2/x3Dk8zhycdIcHjEFb36IS8HVUVFoMg0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91LiqT1nbvzDukyqAlCv89ZmwaHw/ZFlFZg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU= github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= -github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= -github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= -github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= -github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 h1:tWUG+4wZqdMl/znThEk9tcCy8tTMxq8dW0JTgamohrY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2/go.mod h1:U5SNqwhXB3Xe6F47kXvWihPl/ilGaEDe8HD/50Z9wxc= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= diff --git a/upload-server/Dockerfile b/upload-server/Dockerfile new file mode 100644 index 000000000..a38c6fbb8 --- /dev/null +++ b/upload-server/Dockerfile @@ -0,0 +1,3 @@ +FROM gcr.io/distroless/base:debug +ENTRYPOINT [ "/go/bin/netbird-upload" ] +COPY netbird-upload /go/bin/netbird-upload diff --git a/upload-server/main.go b/upload-server/main.go new file mode 100644 index 000000000..dcfb35cdf --- /dev/null +++ b/upload-server/main.go @@ -0,0 +1,22 @@ +package main + +import ( + "errors" + "log" + "net/http" + + "github.com/netbirdio/netbird/upload-server/server" + "github.com/netbirdio/netbird/util" +) + +func main() { + err := util.InitLog("info", "console") + if err != nil { + log.Fatalf("Failed to initialize logger: %v", err) + } + + srv := server.NewServer() + if err = srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Failed to start server: %v", err) + } +} diff --git a/upload-server/server/local.go b/upload-server/server/local.go new file mode 100644 index 000000000..f12c472d2 --- /dev/null +++ b/upload-server/server/local.go @@ -0,0 +1,124 @@ +package server + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/upload-server/types" +) + +const ( + defaultDir = "/var/lib/netbird" + putHandler = "/{dir}/{file}" +) + +type local struct { + url string + dir string +} + +func configureLocalHandlers(mux *http.ServeMux) error { + envURL, ok := os.LookupEnv("SERVER_URL") + if !ok { + return fmt.Errorf("SERVER_URL environment variable is required") + } + _, err := url.Parse(envURL) + if err != nil { + return fmt.Errorf("SERVER_URL environment variable is invalid: %w", err) + } + + dir := defaultDir + envDir, ok := os.LookupEnv("STORE_DIR") + if ok { + if !filepath.IsAbs(envDir) { + return fmt.Errorf("STORE_DIR environment variable should point to an absolute path, e.g. /tmp") + } + log.Infof("Using local directory: %s", envDir) + dir = envDir + } + + l := &local{ + url: envURL, + dir: dir, + } + mux.HandleFunc(types.GetURLPath, l.handlerGetUploadURL) + mux.HandleFunc(putURLPath+putHandler, l.handlePutRequest) + + return nil +} + +func (l *local) handlerGetUploadURL(w http.ResponseWriter, r *http.Request) { + if !isValidRequest(w, r) { + return + } + + objectKey := getObjectKey(w, r) + if objectKey == "" { + return + } + + uploadURL, err := l.getUploadURL(objectKey) + if err != nil { + http.Error(w, "failed to get upload URL", http.StatusInternalServerError) + log.Errorf("Failed to get upload URL: %v", err) + return + } + + respondGetRequest(w, uploadURL, objectKey) +} + +func (l *local) getUploadURL(objectKey string) (string, error) { + parsedUploadURL, err := url.Parse(l.url) + if err != nil { + return "", fmt.Errorf("failed to parse upload URL: %w", err) + } + newURL := parsedUploadURL.JoinPath(parsedUploadURL.Path, putURLPath, objectKey) + return newURL.String(), nil +} + +func (l *local) handlePutRequest(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to read body: %v", err), http.StatusInternalServerError) + return + } + + uploadDir := r.PathValue("dir") + if uploadDir == "" { + http.Error(w, "missing dir path", http.StatusBadRequest) + return + } + uploadFile := r.PathValue("file") + if uploadFile == "" { + http.Error(w, "missing file name", http.StatusBadRequest) + return + } + + dirPath := filepath.Join(l.dir, uploadDir) + err = os.MkdirAll(dirPath, 0750) + if err != nil { + http.Error(w, "failed to create upload dir", http.StatusInternalServerError) + log.Errorf("Failed to create upload dir: %v", err) + return + } + + file := filepath.Join(dirPath, uploadFile) + if err := os.WriteFile(file, body, 0600); err != nil { + http.Error(w, "failed to write file", http.StatusInternalServerError) + log.Errorf("Failed to write file %s: %v", file, err) + return + } + log.Infof("Uploading file %s", file) + w.WriteHeader(http.StatusOK) +} diff --git a/upload-server/server/local_test.go b/upload-server/server/local_test.go new file mode 100644 index 000000000..bd8a87809 --- /dev/null +++ b/upload-server/server/local_test.go @@ -0,0 +1,65 @@ +package server + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/upload-server/types" +) + +func Test_LocalHandlerGetUploadURL(t *testing.T) { + mockURL := "http://localhost:8080" + t.Setenv("SERVER_URL", mockURL) + t.Setenv("STORE_DIR", t.TempDir()) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, types.GetURLPath+"?id=test-file", nil) + req.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response types.GetURLResponse + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + require.Contains(t, response.URL, "test-file/") + require.NotEmpty(t, response.Key) + require.Contains(t, response.Key, "test-file/") + +} + +func Test_LocalHandlePutRequest(t *testing.T) { + mockDir := t.TempDir() + mockURL := "http://localhost:8080" + t.Setenv("SERVER_URL", mockURL) + t.Setenv("STORE_DIR", mockDir) + + mux := http.NewServeMux() + err := configureLocalHandlers(mux) + require.NoError(t, err) + + fileContent := []byte("test file content") + req := httptest.NewRequest(http.MethodPut, putURLPath+"/uploads/test.txt", bytes.NewReader(fileContent)) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + expectedFilePath := filepath.Join(mockDir, "uploads", "test.txt") + createdFileContent, err := os.ReadFile(expectedFilePath) + require.NoError(t, err) + require.Equal(t, fileContent, createdFileContent) +} diff --git a/upload-server/server/s3.go b/upload-server/server/s3.go new file mode 100644 index 000000000..c0976acb5 --- /dev/null +++ b/upload-server/server/s3.go @@ -0,0 +1,69 @@ +package server + +import ( + "context" + "fmt" + "net/http" + "os" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/upload-server/types" +) + +type sThree struct { + ctx context.Context + bucket string + presignClient *s3.PresignClient +} + +func configureS3Handlers(mux *http.ServeMux) error { + bucket := os.Getenv(bucketVar) + region, ok := os.LookupEnv("AWS_REGION") + if !ok { + return fmt.Errorf("AWS_REGION environment variable is required") + } + ctx := context.Background() + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return fmt.Errorf("unable to load SDK config: %w", err) + } + + client := s3.NewFromConfig(cfg) + + handler := &sThree{ + ctx: ctx, + bucket: bucket, + presignClient: s3.NewPresignClient(client), + } + mux.HandleFunc(types.GetURLPath, handler.handlerGetUploadURL) + return nil +} + +func (s *sThree) handlerGetUploadURL(w http.ResponseWriter, r *http.Request) { + if !isValidRequest(w, r) { + return + } + + objectKey := getObjectKey(w, r) + if objectKey == "" { + return + } + + req, err := s.presignClient.PresignPutObject(s.ctx, &s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(objectKey), + }, s3.WithPresignExpires(15*time.Minute)) + + if err != nil { + http.Error(w, "failed to presign URL", http.StatusInternalServerError) + log.Errorf("Presign error: %v", err) + return + } + + respondGetRequest(w, req.URL, objectKey) +} diff --git a/upload-server/server/s3_test.go b/upload-server/server/s3_test.go new file mode 100644 index 000000000..26b0ecd09 --- /dev/null +++ b/upload-server/server/s3_test.go @@ -0,0 +1,103 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "runtime" + "testing" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/netbirdio/netbird/upload-server/types" +) + +func Test_S3HandlerGetUploadURL(t *testing.T) { + if runtime.GOOS != "linux" && os.Getenv("CI") == "true" { + t.Skip("Skipping test on non-Linux and CI environment due to docker dependency") + } + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows due to potential docker dependency") + } + + awsEndpoint := "http://127.0.0.1:4566" + awsRegion := "us-east-1" + + ctx := context.Background() + containerRequest := testcontainers.ContainerRequest{ + Image: "localstack/localstack:s3-latest", + ExposedPorts: []string{"4566:4566/tcp"}, + WaitingFor: wait.ForLog("Ready"), + } + + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: containerRequest, + Started: true, + }) + if err != nil { + t.Error(err) + } + defer func(c testcontainers.Container, ctx context.Context) { + if err := c.Terminate(ctx); err != nil { + t.Log(err) + } + }(c, ctx) + + t.Setenv("AWS_REGION", awsRegion) + t.Setenv("AWS_ENDPOINT_URL", awsEndpoint) + t.Setenv("AWS_ACCESS_KEY_ID", "test") + t.Setenv("AWS_SECRET_ACCESS_KEY", "test") + + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithBaseEndpoint(awsEndpoint)) + if err != nil { + t.Error(err) + } + + client := s3.NewFromConfig(cfg, func(o *s3.Options) { + o.UsePathStyle = true + o.BaseEndpoint = cfg.BaseEndpoint + }) + + bucketName := "test" + if _, err := client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: &bucketName, + }); err != nil { + t.Error(err) + } + + list, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) + if err != nil { + t.Error(err) + } + + assert.Equal(t, len(list.Buckets), 1) + assert.Equal(t, *list.Buckets[0].Name, bucketName) + + t.Setenv(bucketVar, bucketName) + + mux := http.NewServeMux() + err = configureS3Handlers(mux) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, types.GetURLPath+"?id=test-file", nil) + req.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response types.GetURLResponse + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + require.Contains(t, response.URL, "test-file/") + require.NotEmpty(t, response.Key) + require.Contains(t, response.Key, "test-file/") +} diff --git a/upload-server/server/server.go b/upload-server/server/server.go new file mode 100644 index 000000000..29ef72732 --- /dev/null +++ b/upload-server/server/server.go @@ -0,0 +1,109 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "os" + "time" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/upload-server/types" +) + +const ( + putURLPath = "/upload" + bucketVar = "BUCKET" +) + +type Server struct { + srv *http.Server +} + +func NewServer() *Server { + address := os.Getenv("SERVER_ADDRESS") + if address == "" { + log.Infof("SERVER_ADDRESS environment variable was not set, using 0.0.0.0:8080") + address = "0.0.0.0:8080" + } + mux := http.NewServeMux() + err := configureMux(mux) + if err != nil { + log.Fatalf("Failed to configure server: %v", err) + } + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }) + + return &Server{ + srv: &http.Server{Addr: address, Handler: mux}, + } +} + +func (s *Server) Start() error { + log.Infof("Starting upload server on %s", s.srv.Addr) + return s.srv.ListenAndServe() +} + +func (s *Server) Stop() error { + if s.srv != nil { + log.Infof("Stopping upload server on %s", s.srv.Addr) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.srv.Shutdown(ctx) + } + return nil +} + +func configureMux(mux *http.ServeMux) error { + _, ok := os.LookupEnv(bucketVar) + if ok { + return configureS3Handlers(mux) + } else { + return configureLocalHandlers(mux) + } +} + +func getObjectKey(w http.ResponseWriter, r *http.Request) string { + id := r.URL.Query().Get("id") + if id == "" { + http.Error(w, "id query param required", http.StatusBadRequest) + return "" + } + + return id + "/" + uuid.New().String() +} + +func isValidRequest(w http.ResponseWriter, r *http.Request) bool { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return false + } + + if r.Header.Get(types.ClientHeader) != types.ClientHeaderValue { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return false + } + return true +} +func respondGetRequest(w http.ResponseWriter, uploadURL string, objectKey string) { + response := types.GetURLResponse{ + URL: uploadURL, + Key: objectKey, + } + + rdata, err := json.Marshal(response) + if err != nil { + http.Error(w, "failed to marshal response", http.StatusInternalServerError) + log.Errorf("Marshal error: %v", err) + return + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(rdata) + if err != nil { + log.Errorf("Write error: %v", err) + } +} diff --git a/upload-server/types/upload.go b/upload-server/types/upload.go new file mode 100644 index 000000000..35d003582 --- /dev/null +++ b/upload-server/types/upload.go @@ -0,0 +1,16 @@ +package types + +const ( + // ClientHeader is the header used to identify the client + ClientHeader = "x-nb-client" + // ClientHeaderValue is the value of the ClientHeader + ClientHeaderValue = "netbird" + // GetURLPath is the path for the GetURL request + GetURLPath = "/upload-url" +) + +// GetURLResponse is the response for the GetURL request +type GetURLResponse struct { + URL string + Key string +} From d2b42c8f686fb1df8f6a86bb93dacbe0ee2be28b Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Tue, 29 Apr 2025 13:43:42 +0300 Subject: [PATCH 26/45] [client] Add macOS .pkg installer support to installation script (#3755) [client] Add macOS .pkg installer support to installation script --- release_files/install.sh | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index e5a61dcfe..49e313f2f 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -199,6 +199,21 @@ install_native_binaries() { fi } +# Handle macOS .pkg installer +install_pkg() { + case "$(uname -m)" in + x86_64) ARCH="amd64" ;; + arm64|aarch64) ARCH="arm64" ;; + *) echo "Unsupported macOS arch: $(uname -m)" >&2; exit 1 ;; + esac + + PKG_URL=$(curl -sIL -o /dev/null -w '%{url_effective}' "https://pkgs.netbird.io/macos/${ARCH}") + echo "Downloading NetBird macOS installer from https://pkgs.netbird.io/macos/${ARCH}" + curl -fsSL -o /tmp/netbird.pkg "${PKG_URL}" + ${SUDO} installer -pkg /tmp/netbird.pkg -target / + rm -f /tmp/netbird.pkg +} + check_use_bin_variable() { if [ "${USE_BIN_INSTALL}-x" = "true-x" ]; then echo "The installation will be performed using binary files" @@ -265,6 +280,16 @@ install_netbird() { ${SUDO} pacman -Syy add_aur_repo ;; + pkg) + # Check if the package is already installed + if [ -f /Library/Receipts/netbird.pkg ]; then + echo "NetBird is already installed. Please remove it before proceeding." + exit 1 + fi + + # Install the package + install_pkg + ;; brew) # Remove Netbird if it had been installed using Homebrew before if brew ls --versions netbird >/dev/null 2>&1; then @@ -274,7 +299,7 @@ install_netbird() { netbird service stop netbird service uninstall - # Unlik the app + # Unlink the app brew unlink netbird fi @@ -312,7 +337,7 @@ install_netbird() { echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null # Load and start netbird service - if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then + if [ "$PACKAGE_MANAGER" != "rpm-ostree" ] && [ "$PACKAGE_MANAGER" != "pkg" ]; then if ! ${SUDO} netbird service install 2>&1; then echo "NetBird service has already been loaded" fi @@ -451,9 +476,8 @@ if type uname >/dev/null 2>&1; then # Check the availability of a compatible package manager if check_use_bin_variable; then PACKAGE_MANAGER="bin" - elif [ -x "$(command -v brew)" ]; then - PACKAGE_MANAGER="brew" - echo "The installation will be performed using brew package manager" + else + PACKAGE_MANAGER="pkg" fi ;; esac @@ -471,4 +495,4 @@ case "$UPDATE_FLAG" in ;; *) install_netbird -esac +esac \ No newline at end of file From 488e619ec713cbcdd718ca56787e137de26a7a43 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 30 Apr 2025 11:51:40 +0300 Subject: [PATCH 27/45] [management] Add network traffic events pagination (#3580) * Add network traffic events pagination schema --- management/server/http/api/openapi.yml | 94 ++++++++++++++++++++++++- management/server/http/api/types.gen.go | 69 ++++++++++++++++++ 2 files changed, 160 insertions(+), 3 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index c0ce06daa..51ffd65b2 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -2015,6 +2015,32 @@ components: - policy_name - icmp_type - icmp_code + NetworkTrafficEventsResponse: + type: object + properties: + data: + type: array + description: List of network traffic events + items: + $ref: "#/components/schemas/NetworkTrafficEvent" + page: + type: integer + description: Current page number + page_size: + type: integer + description: Number of items per page + total_records: + type: integer + description: Total number of event records available + total_pages: + type: integer + description: Total number of pages available + required: + - data + - page + - page_size + - total_records + - total_pages responses: not_found: description: Resource not found @@ -4231,15 +4257,77 @@ paths: tags: [ Events ] x-cloud-only: true x-experimental: true + parameters: + - name: page + in: query + description: Page number + required: false + schema: + type: integer + minimum: 1 + default: 1 + - name: page_size + in: query + description: Number of items per page + required: false + schema: + type: integer + minimum: 1 + maximum: 50000 + default: 1000 + - name: user_id + in: query + description: Filter by user ID + required: false + schema: + type: string + - name: protocol + in: query + description: Filter by protocol + required: false + schema: + type: integer + - name: type + in: query + description: Filter by event type + required: false + schema: + type: string + enum: [TYPE_UNKNOWN, TYPE_START, TYPE_END, TYPE_DROP] + - name: direction + in: query + description: Filter by direction + required: false + schema: + type: string + enum: [INGRESS, EGRESS, DIRECTION_UNKNOWN] + - name: search + in: query + description: Filters events with a partial match on user email, source and destination names and source and destination addresses + required: false + schema: + type: string + - name: start_date + in: query + description: Start date for filtering events (ISO 8601 format, e.g., 2024-01-01T00:00:00Z). + required: false + schema: + type: string + format: date-time + - name: end_date + in: query + description: End date for filtering events (ISO 8601 format, e.g., 2024-01-31T23:59:59Z). + required: false + schema: + type: string + format: date-time responses: "200": description: List of network traffic events content: application/json: schema: - type: array - items: - $ref: "#/components/schemas/NetworkTrafficEvent" + $ref: "#/components/schemas/NetworkTrafficEventsResponse" '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 243f2fdf9..e01275b99 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -185,6 +185,21 @@ const ( UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited" ) +// Defines values for GetApiEventsNetworkTrafficParamsType. +const ( + GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" + GetApiEventsNetworkTrafficParamsTypeTYPEEND GetApiEventsNetworkTrafficParamsType = "TYPE_END" + GetApiEventsNetworkTrafficParamsTypeTYPESTART GetApiEventsNetworkTrafficParamsType = "TYPE_START" + GetApiEventsNetworkTrafficParamsTypeTYPEUNKNOWN GetApiEventsNetworkTrafficParamsType = "TYPE_UNKNOWN" +) + +// Defines values for GetApiEventsNetworkTrafficParamsDirection. +const ( + GetApiEventsNetworkTrafficParamsDirectionDIRECTIONUNKNOWN GetApiEventsNetworkTrafficParamsDirection = "DIRECTION_UNKNOWN" + GetApiEventsNetworkTrafficParamsDirectionEGRESS GetApiEventsNetworkTrafficParamsDirection = "EGRESS" + GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS" +) + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // CityName Commonly used English name of the city @@ -922,6 +937,24 @@ type NetworkTrafficEvent struct { UserName *string `json:"user_name"` } +// NetworkTrafficEventsResponse defines model for NetworkTrafficEventsResponse. +type NetworkTrafficEventsResponse struct { + // Data List of network traffic events + Data []NetworkTrafficEvent `json:"data"` + + // Page Current page number + Page int `json:"page"` + + // PageSize Number of items per page + PageSize int `json:"page_size"` + + // TotalPages Total number of pages available + TotalPages int `json:"total_pages"` + + // TotalRecords Total number of event records available + TotalRecords int `json:"total_records"` +} + // NetworkTrafficLocation defines model for NetworkTrafficLocation. type NetworkTrafficLocation struct { // CityName Name of the city (if known). @@ -1743,6 +1776,42 @@ type UserRequest struct { Role string `json:"role"` } +// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. +type GetApiEventsNetworkTrafficParams struct { + // Page Page number + Page *int `form:"page,omitempty" json:"page,omitempty"` + + // PageSize Number of items per page + PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"` + + // UserId Filter by user ID + UserId *string `form:"user_id,omitempty" json:"user_id,omitempty"` + + // Protocol Filter by protocol + Protocol *int `form:"protocol,omitempty" json:"protocol,omitempty"` + + // Type Filter by event type + Type *GetApiEventsNetworkTrafficParamsType `form:"type,omitempty" json:"type,omitempty"` + + // Direction Filter by direction + Direction *GetApiEventsNetworkTrafficParamsDirection `form:"direction,omitempty" json:"direction,omitempty"` + + // Search Filters events with a partial match on user email, source and destination names and source and destination addresses + Search *string `form:"search,omitempty" json:"search,omitempty"` + + // StartDate Start date for filtering events (ISO 8601 format, e.g., 2024-01-01T00:00:00Z). + StartDate *time.Time `form:"start_date,omitempty" json:"start_date,omitempty"` + + // EndDate End date for filtering events (ISO 8601 format, e.g., 2024-01-31T23:59:59Z). + EndDate *time.Time `form:"end_date,omitempty" json:"end_date,omitempty"` +} + +// GetApiEventsNetworkTrafficParamsType defines parameters for GetApiEventsNetworkTraffic. +type GetApiEventsNetworkTrafficParamsType string + +// GetApiEventsNetworkTrafficParamsDirection defines parameters for GetApiEventsNetworkTraffic. +type GetApiEventsNetworkTrafficParamsDirection string + // GetApiPeersParams defines parameters for GetApiPeers. type GetApiPeersParams struct { // Name Filter peers by name From d5081cef90ac2b50060e20c0700ca96c9e08c98a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 30 Apr 2025 13:09:00 +0200 Subject: [PATCH 28/45] [client] Revert mgm client error handling (#3764) --- management/client/grpc.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/management/client/grpc.go b/management/client/grpc.go index 956aaebb2..2f4729e23 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -128,13 +128,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return err } - streamErr := c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) - if c.conn.GetState() != connectivity.Shutdown { - if err := c.conn.Close(); err != nil { - log.Warnf("failed closing connection to Management service: %s", err) - } - } - return streamErr + return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) } err := backoff.Retry(operation, defaultBackoff(ctx)) From b5419ef11a6c88d888cc30bb7f6e92c4bdc787e7 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Wed, 30 Apr 2025 15:53:18 +0100 Subject: [PATCH 29/45] [management] limit peers based on module read permission (#3757) --- management/server/peer.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 908610fbe..a4210e3f0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -49,20 +49,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - peers := make([]*nbpeer.Peer, 0) - peersMap := make(map[string]*nbpeer.Peer) - - for _, peer := range accountPeers { - if user.IsRegularUser() && user.Id != peer.UserID { - // only display peers that belong to the current user if the current user is not an admin - continue - } - peers = append(peers, peer) - peersMap[peer.ID] = peer - } - + // @note if the user has permission to read peers it shows all account peers if allowed { - return peers, nil + return accountPeers, nil } settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) @@ -74,6 +63,18 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return []*nbpeer.Peer{}, nil } + // @note if it does not have permission read peers then only display it's own peers + peers := make([]*nbpeer.Peer, 0) + peersMap := make(map[string]*nbpeer.Peer) + + for _, peer := range accountPeers { + if user.Id != peer.UserID { + continue + } + peers = append(peers, peer) + peersMap[peer.ID] = peer + } + return am.getUserAccessiblePeers(ctx, accountID, peersMap, peers) } From 9bc7d788f03b228fdc94c10e16ab968cccfd4cca Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 1 May 2025 00:48:31 +0200 Subject: [PATCH 30/45] [client] Add debug upload option to netbird ui (#3768) --- client/cmd/root.go | 3 +- client/ui/client_ui.go | 43 +++--- client/ui/debug.go | 268 +++++++++++++++++++++++++++++++--- upload-server/types/upload.go | 2 + 4 files changed, 269 insertions(+), 47 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index b4f067078..b57bee230 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -42,7 +42,6 @@ const ( blockLANAccessFlag = "block-lan-access" uploadBundle = "upload-bundle" uploadBundleURL = "upload-bundle-url" - defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath ) var ( @@ -188,7 +187,7 @@ func init() { debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) - debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle") + debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d0b1bacf6..d8c1ee7a2 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -51,7 +51,7 @@ const ( ) func main() { - daemonAddr, showSettings, showNetworks, errorMsg, saveLogsInFile := parseFlags() + daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags() // Initialize file logging if needed. if saveLogsInFile { @@ -72,13 +72,13 @@ func main() { } // Create the service client (this also builds the settings or networks UI if requested). - client := newServiceClient(daemonAddr, a, showSettings, showNetworks) + client := newServiceClient(daemonAddr, a, showSettings, showNetworks, showDebug) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) // Run in window mode if any UI flag was set. - if showSettings || showNetworks { + if showSettings || showNetworks || showDebug { a.Run() return } @@ -99,7 +99,7 @@ func main() { } // parseFlags reads and returns all needed command-line flags. -func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg string, saveLogsInFile bool) { +func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) { defaultDaemonAddr := "unix:///var/run/netbird.sock" if runtime.GOOS == "windows" { defaultDaemonAddr = "tcp://127.0.0.1:41731" @@ -107,24 +107,16 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks bool, errorMsg flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") flag.BoolVar(&showSettings, "settings", false, "run settings window") flag.BoolVar(&showNetworks, "networks", false, "run networks window") + flag.BoolVar(&showDebug, "debug", false, "run debug window") flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") - - tmpDir := "/tmp" - if runtime.GOOS == "windows" { - tmpDir = os.TempDir() - } - flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir)) + flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.Parse() return } // initLogFile initializes logging into a file. func initLogFile() error { - tmpDir := "/tmp" - if runtime.GOOS == "windows" { - tmpDir = os.TempDir() - } - logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) + logFile := path.Join(os.TempDir(), fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) return util.InitLog("trace", logFile) } @@ -231,7 +223,7 @@ type serviceClient struct { daemonVersion string updateIndicationLock sync.Mutex isUpdateIconActive bool - showRoutes bool + showNetworks bool wRoutes fyne.Window eventManager *event.Manager @@ -248,7 +240,7 @@ type menuHandler struct { // newServiceClient instance constructor // // This constructor also builds the UI elements for the settings window. -func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient { +func newServiceClient(addr string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { s := &serviceClient{ ctx: context.Background(), addr: addr, @@ -256,17 +248,21 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo sendNotification: false, showAdvancedSettings: showSettings, - showRoutes: showRoutes, + showNetworks: showNetworks, update: version.NewUpdate(), } s.setNewIcons() - if showSettings { + switch { + case showSettings: + s.showSettingsUI() return s - } else if showRoutes { + case showNetworks: s.showNetworksUI() + case showDebug: + s.showDebugUI() } return s @@ -743,11 +739,10 @@ func (s *serviceClient) onTrayReady() { s.runSelfCommand("settings", "true") }() case <-s.mCreateDebugBundle.ClickedCh: + s.mCreateDebugBundle.Disable() go func() { - if err := s.createAndOpenDebugBundle(); err != nil { - log.Errorf("Failed to create debug bundle: %v", err) - s.app.SendNotification(fyne.NewNotification("Error", "Failed to create debug bundle")) - } + defer s.mCreateDebugBundle.Enable() + s.runSelfCommand("debug", "true") }() case <-s.mQuit.ClickedCh: systray.Quit() diff --git a/client/ui/debug.go b/client/ui/debug.go index 845ea284c..e950e6d1e 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -7,44 +7,270 @@ import ( "path/filepath" "fyne.io/fyne/v2" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/widget" + log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" + uptypes "github.com/netbirdio/netbird/upload-server/types" ) -func (s *serviceClient) createAndOpenDebugBundle() error { +func (s *serviceClient) showDebugUI() { + w := s.app.NewWindow("NetBird Debug") + w.Resize(fyne.NewSize(600, 400)) + w.SetFixedSize(true) + + anonymizeCheck := widget.NewCheck("Anonymize sensitive information (Public IPs, domains, ...)", nil) + systemInfoCheck := widget.NewCheck("Include system information", nil) + systemInfoCheck.SetChecked(true) + uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) + uploadCheck.SetChecked(true) + + uploadURLLabel := widget.NewLabel("Debug upload URL:") + uploadURL := widget.NewEntry() + uploadURL.SetText(uptypes.DefaultBundleURL) + uploadURL.SetPlaceHolder("Enter upload URL") + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + + createButton := widget.NewButton("Create Debug Bundle", nil) + + uploadURLContainer := container.NewVBox( + uploadURLLabel, + uploadURL, + ) + + uploadCheck.OnChanged = func(checked bool) { + if checked { + uploadURLContainer.Show() + } else { + uploadURLContainer.Hide() + } + } + + createButton.OnTapped = s.getCreateHandler(createButton, statusLabel, uploadCheck, uploadURL, anonymizeCheck, systemInfoCheck, w) + + content := container.NewVBox( + widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), + widget.NewLabel(""), + anonymizeCheck, + systemInfoCheck, + uploadCheck, + uploadURLContainer, + widget.NewLabel(""), + statusLabel, + createButton, + ) + + paddedContent := container.NewPadded(content) + w.SetContent(paddedContent) + + w.Show() +} + +func (s *serviceClient) getCreateHandler( + createButton *widget.Button, + statusLabel *widget.Label, + uploadCheck *widget.Check, + uploadURL *widget.Entry, + anonymizeCheck *widget.Check, + systemInfoCheck *widget.Check, + w fyne.Window, +) func() { + return func() { + createButton.Disable() + statusLabel.SetText("Creating debug bundle...") + statusLabel.Show() + + var url string + if uploadCheck.Checked { + url = uploadURL.Text + if url == "" { + statusLabel.SetText("Error: Upload URL is required when upload is enabled") + createButton.Enable() + return + } + } + + go s.handleDebugCreation(anonymizeCheck.Checked, systemInfoCheck.Checked, uploadCheck.Checked, url, statusLabel, createButton, w) + } +} + +func (s *serviceClient) handleDebugCreation( + anonymize bool, + systemInfo bool, + upload bool, + uploadURL string, + statusLabel *widget.Label, + createButton *widget.Button, + w fyne.Window, +) { + log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", + anonymize, systemInfo, upload) + + resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL) + if err != nil { + log.Errorf("Failed to create debug bundle: %v", err) + statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) + createButton.Enable() + return + } + + localPath := resp.GetPath() + uploadFailureReason := resp.GetUploadFailureReason() + uploadedKey := resp.GetUploadedKey() + + if upload { + if uploadFailureReason != "" { + showUploadFailedDialog(w, localPath, uploadFailureReason) + } else { + showUploadSuccessDialog(w, localPath, uploadedKey) + } + } else { + showBundleCreatedDialog(w, localPath) + } + + createButton.Enable() + statusLabel.SetText("Bundle created successfully") +} + +func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploadURL string) (*proto.DebugBundleResponse, error) { conn, err := s.getSrvClient(failFastTimeout) if err != nil { - return fmt.Errorf("get client: %v", err) + return nil, fmt.Errorf("get client: %v", err) } statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { - return fmt.Errorf("failed to get status: %v", err) + log.Warnf("failed to get status for debug bundle: %v", err) } - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, true, "", nil, nil, nil) - statusOutput := nbstatus.ParseToFullDetailSummary(overview) + var statusOutput string + if statusResp != nil { + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil) + statusOutput = nbstatus.ParseToFullDetailSummary(overview) + } - resp, err := conn.DebugBundle(s.ctx, &proto.DebugBundleRequest{ - Anonymize: true, + request := &proto.DebugBundleRequest{ + Anonymize: anonymize, Status: statusOutput, - SystemInfo: true, - }) + SystemInfo: systemInfo, + } + + if uploadURL != "" { + request.UploadURL = uploadURL + } + + resp, err := conn.DebugBundle(s.ctx, request) if err != nil { - return fmt.Errorf("failed to create debug bundle: %v", err) + return nil, fmt.Errorf("failed to create debug bundle via daemon: %v", err) } - bundleDir := filepath.Dir(resp.GetPath()) - if err := open.Start(bundleDir); err != nil { - return fmt.Errorf("failed to open debug bundle directory: %v", err) - } - - s.app.SendNotification(fyne.NewNotification( - "Debug Bundle", - fmt.Sprintf("Debug bundle created at %s. Administrator privileges are required to access it.", resp.GetPath()), - )) - - return nil + return resp, nil +} + +// showUploadFailedDialog displays a dialog when upload fails +func showUploadFailedDialog(parent fyne.Window, localPath, failureReason string) { + content := container.NewVBox( + widget.NewLabel(fmt.Sprintf("Bundle upload failed:\n%s\n\n"+ + "A local copy was saved at:\n%s", failureReason, localPath)), + ) + + customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, parent) + + buttonBox := container.NewHBox( + widget.NewButton("Open File", func() { + log.Infof("Attempting to open local file: %s", localPath) + if openErr := open.Start(localPath); openErr != nil { + log.Errorf("Failed to open local file '%s': %v", localPath, openErr) + dialog.ShowError(fmt.Errorf("Failed to open the local file:\n%s\n\nError: %v", localPath, openErr), parent) + } + customDialog.Hide() + }), + widget.NewButton("Open Folder", func() { + folderPath := filepath.Dir(localPath) + log.Infof("Attempting to open local folder: %s", folderPath) + if openErr := open.Start(folderPath); openErr != nil { + log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) + dialog.ShowError(fmt.Errorf("Failed to open the local folder:\n%s\n\nError: %v", folderPath, openErr), parent) + } + customDialog.Hide() + }), + ) + + content.Add(buttonBox) + customDialog.Show() +} + +// showUploadSuccessDialog displays a dialog when upload succeeds +func showUploadSuccessDialog(parent fyne.Window, localPath, uploadedKey string) { + keyEntry := widget.NewEntry() + keyEntry.SetText(uploadedKey) + keyEntry.Disable() + + content := container.NewVBox( + widget.NewLabel("Bundle uploaded successfully!"), + widget.NewLabel(""), + widget.NewLabel("Upload Key:"), + keyEntry, + widget.NewLabel(""), + widget.NewLabel(fmt.Sprintf("Local copy saved at:\n%s", localPath)), + ) + + customDialog := dialog.NewCustom("Upload Successful", "OK", content, parent) + + buttonBox := container.NewHBox( + widget.NewButton("Copy Key", func() { + parent.Clipboard().SetContent(uploadedKey) + log.Info("Upload key copied to clipboard") + }), + widget.NewButton("Open Local Folder", func() { + folderPath := filepath.Dir(localPath) + log.Infof("Attempting to open local folder: %s", folderPath) + if openErr := open.Start(folderPath); openErr != nil { + log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) + dialog.ShowError(fmt.Errorf("Failed to open the local folder:\n%s\n\nError: %v", folderPath, openErr), parent) + } + }), + ) + + content.Add(buttonBox) + customDialog.Show() +} + +// showBundleCreatedDialog displays a dialog when bundle is created without upload +func showBundleCreatedDialog(parent fyne.Window, localPath string) { + content := container.NewVBox( + widget.NewLabel(fmt.Sprintf("Bundle created locally at:\n%s\n\n"+ + "Administrator privileges may be required to access the file.", localPath)), + ) + + customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, parent) + + buttonBox := container.NewHBox( + widget.NewButton("Open File", func() { + log.Infof("Attempting to open local file: %s", localPath) + if openErr := open.Start(localPath); openErr != nil { + log.Errorf("Failed to open local file '%s': %v", localPath, openErr) + dialog.ShowError(fmt.Errorf("Failed to open the local file:\n%s\n\nError: %v", localPath, openErr), parent) + } + customDialog.Hide() + }), + widget.NewButton("Open Folder", func() { + folderPath := filepath.Dir(localPath) + log.Infof("Attempting to open local folder: %s", folderPath) + if openErr := open.Start(folderPath); openErr != nil { + log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) + dialog.ShowError(fmt.Errorf("Failed to open the local folder:\n%s\n\nError: %v", folderPath, openErr), parent) + } + customDialog.Hide() + }), + ) + + content.Add(buttonBox) + customDialog.Show() } diff --git a/upload-server/types/upload.go b/upload-server/types/upload.go index 35d003582..327c28e75 100644 --- a/upload-server/types/upload.go +++ b/upload-server/types/upload.go @@ -7,6 +7,8 @@ const ( ClientHeaderValue = "netbird" // GetURLPath is the path for the GetURL request GetURLPath = "/upload-url" + + DefaultBundleURL = "https://upload.debug.netbird.io" + GetURLPath ) // GetURLResponse is the response for the GetURL request From 7b64953eedfb271f64fe290a112f3e41cb832fe6 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Thu, 1 May 2025 11:24:55 +0100 Subject: [PATCH 31/45] [management] user info with role permissions (#3728) --- management/client/rest/users_test.go | 7 +- management/server/account/manager.go | 3 +- management/server/http/api/openapi.yml | 24 +- management/server/http/api/types.gen.go | 15 +- .../http/handlers/users/users_handler.go | 30 +- .../http/handlers/users/users_handler_test.go | 197 +++++++++--- management/server/mock_server/account_mock.go | 7 +- management/server/peer.go | 2 +- management/server/permissions/manager.go | 21 ++ management/server/permissions/manager_mock.go | 15 + .../server/permissions/modules/module.go | 16 + .../server/permissions/roles/network_admin.go | 12 +- management/server/types/user.go | 26 +- management/server/user.go | 60 ++-- management/server/user_test.go | 288 ++++++++---------- management/server/users/user.go | 14 + 16 files changed, 446 insertions(+), 291 deletions(-) create mode 100644 management/server/users/user.go diff --git a/management/client/rest/users_test.go b/management/client/rest/users_test.go index f68c5f083..715eb1661 100644 --- a/management/client/rest/users_test.go +++ b/management/client/rest/users_test.go @@ -30,11 +30,8 @@ var ( Issued: ptr("api"), LastLogin: &time.Time{}, Name: "M. Essam", - Permissions: &api.UserPermissions{ - DashboardView: ptr(api.UserPermissionsDashboardViewFull), - }, - Role: "user", - Status: api.UserStatusActive, + Role: "user", + Status: api.UserStatusActive, } ) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index aed83349f..9bc4f9605 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -115,5 +116,5 @@ type Manager interface { CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) - GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) + GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 51ffd65b2..bf40777fc 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -216,11 +216,25 @@ components: UserPermissions: type: object properties: - dashboard_view: - description: User's permission to view the dashboard - type: string - enum: [ "limited", "blocked", "full" ] - example: limited + is_restricted: + type: boolean + description: Indicates whether this User's Peers view is restricted + modules: + type: object + additionalProperties: + type: object + additionalProperties: + type: boolean + propertyNames: + type: string + description: The operation type + propertyNames: + type: string + description: The module name + example: {"networks": { "read": true, "create": false, "update": false, "delete": false}, "peers": { "read": false, "create": false, "update": false, "delete": false} } + required: + - modules + - is_restricted UserRequest: type: object properties: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index e01275b99..e108c6884 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -178,13 +178,6 @@ const ( UserStatusInvited UserStatus = "invited" ) -// Defines values for UserPermissionsDashboardView. -const ( - UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked" - UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full" - UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited" -) - // Defines values for GetApiEventsNetworkTrafficParamsType. const ( GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" @@ -1757,13 +1750,11 @@ type UserCreateRequest struct { // UserPermissions defines model for UserPermissions. type UserPermissions struct { - // DashboardView User's permission to view the dashboard - DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"` + // IsRestricted Indicates whether this User's Peers view is restricted + IsRestricted bool `json:"is_restricted"` + Modules map[string]map[string]bool `json:"modules"` } -// UserPermissionsDashboardView User's permission to view the dashboard -type UserPermissionsDashboardView string - // UserRequest defines model for UserRequest. type UserRequest struct { // AutoGroups Group IDs to auto-assign to peers registered by this user diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index c69c6b944..ac04b8e35 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" nbcontext "github.com/netbirdio/netbird/management/server/context" ) @@ -272,15 +273,33 @@ func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) { return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - user, err := h.accountManager.GetCurrentUserInfo(ctx, accountID, userID) + user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth) if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(user, userID)) + util.WriteJSONObject(r.Context(), w, toUserWithPermissionsResponse(user, userAuth.UserId)) +} + +func toUserWithPermissionsResponse(user *users.UserInfoWithPermissions, userID string) *api.User { + response := toUserResponse(user.UserInfo, userID) + + // stringify modules and operations keys + modules := make(map[string]map[string]bool) + for module, operations := range user.Permissions { + modules[string(module)] = make(map[string]bool) + for op, val := range operations { + modules[string(module)][string(op)] = val + } + } + + response.Permissions = &api.UserPermissions{ + IsRestricted: user.Restricted, + Modules: modules, + } + + return response } func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { @@ -316,8 +335,5 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { IsBlocked: user.IsBlocked, LastLogin: &user.LastLogin, Issued: &user.Issued, - Permissions: &api.UserPermissions{ - DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView), - }, } } diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 604954819..58e33a6d5 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -13,12 +13,16 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" ) const ( @@ -107,7 +111,7 @@ func initUsersTestData() *handler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false}) + info, err := update.Copy().ToUserInfo(nil) if err != nil { return nil, err } @@ -124,8 +128,8 @@ func initUsersTestData() *handler { return nil }, - GetCurrentUserInfoFunc: func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { - switch userID { + GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + switch userAuth.UserId { case "not-found": return nil, status.NewUserNotFoundError("not-found") case "not-of-account": @@ -135,52 +139,68 @@ func initUsersTestData() *handler { case "service-user": return nil, status.NewPermissionDeniedError() case "owner": - return &types.UserInfo{ - ID: "owner", - Name: "", - Role: "owner", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - Issued: "api", - Permissions: types.UserPermissions{ - DashboardView: "full", + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "owner", + Name: "", + Role: "owner", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + Issued: "api", }, + Permissions: mergeRolePermissions(roles.Owner), }, nil case "regular-user": - return &types.UserInfo{ - ID: "regular-user", - Name: "", - Role: "user", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - Issued: "api", - Permissions: types.UserPermissions{ - DashboardView: "limited", + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "regular-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + Issued: "api", }, + Permissions: mergeRolePermissions(roles.User), }, nil case "admin-user": - return &types.UserInfo{ - ID: "admin-user", - Name: "", - Role: "admin", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - Permissions: types.UserPermissions{ - DashboardView: "full", + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "admin-user", + Name: "", + Role: "admin", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", }, + Permissions: mergeRolePermissions(roles.Admin), + }, nil + case "restricted-user": + return &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "restricted-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + }, + Permissions: mergeRolePermissions(roles.User), + Restricted: true, }, nil } - return nil, fmt.Errorf("user id %s not handled", userID) + return nil, fmt.Errorf("user id %s not handled", userAuth.UserId) }, }, } @@ -546,6 +566,7 @@ func TestCurrentUser(t *testing.T) { name string expectedStatus int requestAuth nbcontext.UserAuth + expectedResult *api.User }{ { name: "without auth", @@ -575,16 +596,78 @@ func TestCurrentUser(t *testing.T) { name: "owner", requestAuth: nbcontext.UserAuth{UserId: "owner"}, expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "owner", + Role: "owner", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)), + }, + }, }, { name: "regular user", requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "regular-user", + Role: "user", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), + }, + }, }, { name: "admin user", requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "admin-user", + Role: "admin", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)), + }, + }, + }, + { + name: "restricted user", + requestAuth: nbcontext.UserAuth{UserId: "restricted-user"}, + expectedStatus: http.StatusOK, + expectedResult: &api.User{ + Id: "restricted-user", + Role: "user", + Status: "active", + IsBlocked: false, + IsCurrent: ptr(true), + IsServiceUser: ptr(false), + AutoGroups: []string{}, + Issued: ptr("api"), + LastLogin: ptr(time.Time{}), + Permissions: &api.UserPermissions{ + IsRestricted: true, + Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)), + }, + }, }, } @@ -603,10 +686,42 @@ func TestCurrentUser(t *testing.T) { res := rr.Result() defer res.Body.Close() - if status := rr.Code; status != tc.expectedStatus { - t.Fatalf("handler returned wrong status code: got %v want %v", - status, tc.expectedStatus) + assert.Equal(t, tc.expectedStatus, rr.Code, "handler returned wrong status code") + + if tc.expectedResult != nil { + var result api.User + require.NoError(t, json.NewDecoder(res.Body).Decode(&result)) + assert.EqualValues(t, *tc.expectedResult, result) } }) } } + +func ptr[T any, PT *T](x T) PT { + return &x +} + +func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { + permissions := roles.Permissions{} + + for k := range modules.All { + if rolePermissions, ok := role.Permissions[k]; ok { + permissions[k] = rolePermissions + continue + } + permissions[k] = role.AutoAllowNew + } + + return permissions +} + +func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool { + modules := make(map[string]map[string]bool) + for module, operations := range permissions { + modules[string(module)] = make(map[string]bool) + for op, val := range operations { + modules[string(module)][string(op)] = val + } + } + return modules +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2b57e6888..0dd3f927e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -115,7 +116,7 @@ type MockAccountManager struct { CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error) UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, accountID, userID string) (*types.UserInfo, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) } @@ -882,9 +883,9 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } -func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { if am.GetCurrentUserInfoFunc != nil { - return am.GetCurrentUserInfoFunc(ctx, accountID, userID) + return am.GetCurrentUserInfoFunc(ctx, userAuth) } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index a4210e3f0..9ff80442e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -59,7 +59,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, fmt.Errorf("failed to get account settings: %w", err) } - if settings.RegularUsersViewBlocked { + if user.IsRestrictable() && settings.RegularUsersViewBlocked { return []*nbpeer.Peer{}, nil } diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 50a44eb0f..ebbce5d4a 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -20,6 +20,8 @@ type Manager interface { ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error + + GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) } type managerImpl struct { @@ -96,3 +98,22 @@ func (m *managerImpl) ValidateAccountAccess(ctx context.Context, accountID strin } return nil } + +func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) { + roleMap, ok := roles.RolesMap[role] + if !ok { + return roles.Permissions{}, status.NewUserRoleNotFoundError(string(role)) + } + + permissions := roles.Permissions{} + + for k := range modules.All { + if rolePermissions, ok := roleMap.Permissions[k]; ok { + permissions[k] = rolePermissions + continue + } + permissions[k] = roleMap.AutoAllowNew + } + + return permissions, nil +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index 266a24270..fa115d628 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -38,6 +38,21 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// GetPermissionsByRole mocks base method. +func (m *MockManager) GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPermissionsByRole", ctx, role) + ret0, _ := ret[0].(roles.Permissions) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPermissionsByRole indicates an expected call of GetPermissionsByRole. +func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go index 4c42b6190..3d021a235 100644 --- a/management/server/permissions/modules/module.go +++ b/management/server/permissions/modules/module.go @@ -17,3 +17,19 @@ const ( SetupKeys Module = "setup_keys" Pats Module = "pats" ) + +var All = map[Module]struct{}{ + Networks: {}, + Peers: {}, + Groups: {}, + Settings: {}, + Accounts: {}, + Dns: {}, + Nameservers: {}, + Events: {}, + Policies: {}, + Routes: {}, + Users: {}, + SetupKeys: {}, + Pats: {}, +} diff --git a/management/server/permissions/roles/network_admin.go b/management/server/permissions/roles/network_admin.go index 761933386..e95d58381 100644 --- a/management/server/permissions/roles/network_admin.go +++ b/management/server/permissions/roles/network_admin.go @@ -23,9 +23,9 @@ var NetworkAdmin = RolePermissions{ }, modules.Groups: { operations.Read: true, - operations.Create: false, - operations.Update: false, - operations.Delete: false, + operations.Create: true, + operations.Update: true, + operations.Delete: true, }, modules.Settings: { operations.Read: true, @@ -87,5 +87,11 @@ var NetworkAdmin = RolePermissions{ operations.Update: true, operations.Delete: true, }, + modules.Peers: { + operations.Read: true, + operations.Create: false, + operations.Update: false, + operations.Delete: false, + }, }, } diff --git a/management/server/types/user.go b/management/server/types/user.go index a2596b3cb..783fe14da 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -65,11 +65,6 @@ type UserInfo struct { LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` -} - -type UserPermissions struct { - DashboardView string `json:"dashboard_view"` } // User represents a user of the system @@ -132,21 +127,18 @@ func (u *User) IsRegularUser() bool { return !u.HasAdminPower() && !u.IsServiceUser } +// IsRestrictable checks whether a user is in a restrictable role. +func (u *User) IsRestrictable() bool { + return u.Role == UserRoleUser || u.Role == UserRoleBillingAdmin +} + // ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { +func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { autoGroups := u.AutoGroups if autoGroups == nil { autoGroups = []string{} } - dashboardViewPermissions := "full" - if !u.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - if userData == nil { return &UserInfo{ ID: u.Id, @@ -159,9 +151,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo IsBlocked: u.Blocked, LastLogin: u.GetLastLogin(), Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, }, nil } if userData.ID != u.Id { @@ -184,9 +173,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo IsBlocked: u.Blocked, LastLogin: u.GetLastLogin(), Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, }, nil } diff --git a/management/server/user.go b/management/server/user.go index b46ed24cf..44ad3b68f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -19,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" ) @@ -122,11 +124,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u CreatedAt: time.Now().UTC(), } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { return nil, err } @@ -138,7 +135,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) - return newUser.ToUserInfo(idpUser, settings) + return newUser.ToUserInfo(idpUser) } // createNewIdpUser validates the invite and creates a new user in the IdP @@ -360,6 +357,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, err } + // @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) { return nil, status.NewAdminPermissionError() } @@ -727,19 +725,14 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) { - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - if !isNil(am.idpManager) && !user.IsServiceUser { userData, err := am.lookupUserInCache(ctx, user.Id, accountID) if err != nil { return nil, err } - return user.ToUserInfo(userData, settings) + return user.ToUserInfo(userData) } - return user.ToUserInfo(nil, settings) + return user.ToUserInfo(nil) } // validateUserUpdate validates the update operation for a user. @@ -879,17 +872,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a queriedUsers = append(queriedUsers, usersFromIntegration...) } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - userInfosMap := make(map[string]*types.UserInfo) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { for _, accountUser := range accountUsers { - info, err := accountUser.ToUserInfo(nil, settings) + info, err := accountUser.ToUserInfo(nil) if err != nil { return nil, err } @@ -902,7 +890,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a for _, localUser := range accountUsers { var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser, settings) + info, err = localUser.ToUserInfo(queriedUser) if err != nil { return nil, err } @@ -912,14 +900,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a name = localUser.ServiceUserName } - dashboardViewPermissions := "full" - if !localUser.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - info = &types.UserInfo{ ID: localUser.Id, Email: "", @@ -929,7 +909,6 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a Status: string(types.UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, - Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfosMap[info.ID] = info @@ -1239,8 +1218,10 @@ func validateUserInvite(invite *types.UserInfo) error { return nil } -// GetCurrentUserInfo retrieves the account's current user info -func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, accountID, userID string) (*types.UserInfo, error) { +// GetCurrentUserInfo retrieves the account's current user info and permissions +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + accountID, userID := userAuth.AccountId, userAuth.UserId + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err @@ -1258,10 +1239,25 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, account return nil, err } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + userInfo, err := am.getUserInfo(ctx, user, accountID) if err != nil { return nil, err } - return userInfo, nil + userWithPermissions := &users.UserInfoWithPermissions{ + UserInfo: userInfo, + Restricted: !userAuth.IsChild && user.IsRestrictable() && settings.RegularUsersViewBlocked, + } + + permissions, err := am.permissionsManager.GetPermissionsByRole(ctx, user.Role) + if err == nil { + userWithPermissions.Permissions = permissions + } + + return userWithPermissions, nil } diff --git a/management/server/user_test.go b/management/server/user_test.go index 83c5ac49a..66bdc1683 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -13,7 +13,10 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -1020,90 +1023,6 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { assert.Equal(t, 2, regular) } -func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { - testCases := []struct { - name string - role types.UserRole - limitedViewSettings bool - expectedDashboardPermissions string - }{ - { - name: "Regular user, no limited view settings", - role: types.UserRoleUser, - limitedViewSettings: false, - expectedDashboardPermissions: "limited", - }, - { - name: "Admin user, no limited view settings", - role: types.UserRoleAdmin, - limitedViewSettings: false, - expectedDashboardPermissions: "full", - }, - { - name: "Owner, no limited view settings", - role: types.UserRoleOwner, - limitedViewSettings: false, - expectedDashboardPermissions: "full", - }, - { - name: "Regular user, limited view settings", - role: types.UserRoleUser, - limitedViewSettings: true, - expectedDashboardPermissions: "blocked", - }, - { - name: "Admin user, limited view settings", - role: types.UserRoleAdmin, - limitedViewSettings: true, - expectedDashboardPermissions: "full", - }, - { - name: "Owner, limited view settings", - role: types.UserRoleOwner, - limitedViewSettings: true, - expectedDashboardPermissions: "full", - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - if err != nil { - t.Fatalf("Error when creating store: %s", err) - } - t.Cleanup(cleanup) - - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI) - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - delete(account.Users, mockUserID) - - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } - - permissionsManager := permissions.NewManager(store) - am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - permissionsManager: permissionsManager, - } - - users, err := am.ListUsers(context.Background(), mockAccountID) - if err != nil { - t.Fatalf("Error when checking user role: %s", err) - } - - assert.Equal(t, 1, len(users)) - - userInfo, _ := users[0].ToUserInfo(nil, account.Settings) - assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) - }) - } - -} - func TestDefaultAccountManager_ExternalCache(t *testing.T) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) if err != nil { @@ -1654,121 +1573,154 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { tt := []struct { name string - accountId string - userId string + userAuth nbcontext.UserAuth expectedErr error - expectedResult *types.UserInfo + expectedResult *users.UserInfoWithPermissions }{ { name: "not found", - accountId: account1.Id, - userId: "not-found", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, { name: "not part of account", - accountId: account1.Id, - userId: "account2Owner", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, expectedErr: status.NewUserNotPartOfAccountError(), }, { name: "blocked", - accountId: account1.Id, - userId: "blocked-user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, expectedErr: status.NewUserBlockedError(), }, { name: "service user", - accountId: account1.Id, - userId: "service-user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"}, expectedErr: status.NewPermissionDeniedError(), }, { - name: "owner user", - accountId: account1.Id, - userId: "account1Owner", - expectedResult: &types.UserInfo{ - ID: "account1Owner", - Name: "", - Role: "owner", - AutoGroups: []string{}, - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "full", + name: "owner user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "account1Owner", + Name: "", + Role: "owner", + AutoGroups: []string{}, + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.Owner), }, }, { - name: "regular user", - accountId: account1.Id, - userId: "regular-user", - expectedResult: &types.UserInfo{ - ID: "regular-user", - Name: "", - Role: "user", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "limited", + name: "regular user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "regular-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.User), }, }, { - name: "admin user", - accountId: account1.Id, - userId: "admin-user", - expectedResult: &types.UserInfo{ - ID: "admin-user", - Name: "", - Role: "admin", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "full", + name: "admin user", + userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "admin-user", + Name: "", + Role: "admin", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.Admin), }, }, { - name: "settings blocked regular user", - accountId: account2.Id, - userId: "settings-blocked-user", - expectedResult: &types.UserInfo{ - ID: "settings-blocked-user", - Name: "", - Role: "user", - Status: "active", - IsServiceUser: false, - IsBlocked: false, - NonDeletable: false, - LastLogin: time.Time{}, - Issued: "api", - IntegrationReference: integration_reference.IntegrationReference{}, - Permissions: types.UserPermissions{ - DashboardView: "blocked", + name: "settings blocked regular user", + userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "settings-blocked-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, }, + Permissions: mergeRolePermissions(roles.User), + Restricted: true, + }, + }, + + { + name: "settings blocked regular user child account", + userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "settings-blocked-user", + Name: "", + Role: "user", + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + }, + Permissions: mergeRolePermissions(roles.User), + Restricted: false, + }, + }, + { + name: "settings blocked owner user", + userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, + expectedResult: &users.UserInfoWithPermissions{ + UserInfo: &types.UserInfo{ + ID: "account2Owner", + Name: "", + Role: "owner", + AutoGroups: []string{}, + Status: "active", + IsServiceUser: false, + IsBlocked: false, + NonDeletable: false, + LastLogin: time.Time{}, + Issued: "api", + IntegrationReference: integration_reference.IntegrationReference{}, + }, + Permissions: mergeRolePermissions(roles.Owner), }, }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - result, err := am.GetCurrentUserInfo(context.Background(), tc.accountId, tc.userId) + result, err := am.GetCurrentUserInfo(context.Background(), tc.userAuth) if tc.expectedErr != nil { assert.Equal(t, err, tc.expectedErr) @@ -1780,3 +1732,17 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }) } } + +func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { + permissions := roles.Permissions{} + + for k := range modules.All { + if rolePermissions, ok := role.Permissions[k]; ok { + permissions[k] = rolePermissions + continue + } + permissions[k] = role.AutoAllowNew + } + + return permissions +} diff --git a/management/server/users/user.go b/management/server/users/user.go new file mode 100644 index 000000000..2f2788271 --- /dev/null +++ b/management/server/users/user.go @@ -0,0 +1,14 @@ +package users + +import ( + "github.com/netbirdio/netbird/management/server/permissions/roles" + "github.com/netbirdio/netbird/management/server/types" +) + +// Wrapped UserInfo with Role Permissions +type UserInfoWithPermissions struct { + *types.UserInfo + + Permissions roles.Permissions + Restricted bool +} From 01c3719c5d5a2a034e03672a5da181ce34c5d8ea Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 1 May 2025 23:25:27 +0200 Subject: [PATCH 32/45] [client] Add debug for duration option to netbird ui (#3772) --- client/cmd/debug.go | 7 - client/server/debug.go | 4 +- client/ui/client_ui.go | 84 ++++-- client/ui/debug.go | 575 ++++++++++++++++++++++++++++++++++++----- client/ui/network.go | 18 +- 5 files changed, 586 insertions(+), 102 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index b4adee826..385bd95f5 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -235,13 +235,6 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } - // Disable network map persistence after creating the debug bundle - if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ - Enabled: false, - }); err != nil { - return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message()) - } - if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) diff --git a/client/server/debug.go b/client/server/debug.go index b42b1467a..7de3e8609 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -51,14 +51,16 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( } if req.GetUploadURL() == "" { - return &proto.DebugBundleResponse{Path: path}, nil } key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) if err != nil { + log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err) return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil } + log.Infof("debug bundle uploaded to %s with key %s", req.GetUploadURL(), key) + return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d8c1ee7a2..2c8023185 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -54,11 +54,14 @@ func main() { daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags() // Initialize file logging if needed. + var logFile string if saveLogsInFile { - if err := initLogFile(); err != nil { + file, err := initLogFile() + if err != nil { log.Errorf("error while initializing log: %v", err) return } + logFile = file } // Create the Fyne application. @@ -72,7 +75,7 @@ func main() { } // Create the service client (this also builds the settings or networks UI if requested). - client := newServiceClient(daemonAddr, a, showSettings, showNetworks, showDebug) + client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug) // Watch for theme/settings changes to update the icon. go watchSettingsChanges(a, client) @@ -115,9 +118,9 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool } // initLogFile initializes logging into a file. -func initLogFile() error { +func initLogFile() (string, error) { logFile := path.Join(os.TempDir(), fmt.Sprintf("netbird-ui-%d.log", os.Getpid())) - return util.InitLog("trace", logFile) + return logFile, util.InitLog("trace", logFile) } // watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon. @@ -160,9 +163,10 @@ var iconConnectingMacOS []byte var iconErrorMacOS []byte type serviceClient struct { - ctx context.Context - addr string - conn proto.DaemonServiceClient + ctx context.Context + cancel context.CancelFunc + addr string + conn proto.DaemonServiceClient icAbout []byte icConnected []byte @@ -224,12 +228,13 @@ type serviceClient struct { updateIndicationLock sync.Mutex isUpdateIconActive bool showNetworks bool - wRoutes fyne.Window + wNetworks fyne.Window eventManager *event.Manager exitNodeMu sync.Mutex mExitNodeItems []menuHandler + logFile string } type menuHandler struct { @@ -240,11 +245,14 @@ type menuHandler struct { // newServiceClient instance constructor // // This constructor also builds the UI elements for the settings window. -func newServiceClient(addr string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { +func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { + ctx, cancel := context.WithCancel(context.Background()) s := &serviceClient{ - ctx: context.Background(), + ctx: ctx, + cancel: cancel, addr: addr, app: a, + logFile: logFile, sendNotification: false, showAdvancedSettings: showSettings, @@ -256,9 +264,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showNetworks b switch { case showSettings: - s.showSettingsUI() - return s case showNetworks: s.showNetworksUI() case showDebug: @@ -309,6 +315,8 @@ func (s *serviceClient) updateIcon() { func (s *serviceClient) showSettingsUI() { // add settings window UI elements. s.wSettings = s.app.NewWindow("NetBird Settings") + s.wSettings.SetOnClosed(s.cancel) + s.iMngURL = widget.NewEntry() s.iAdminURL = widget.NewEntry() s.iConfigFile = widget.NewEntry() @@ -784,7 +792,7 @@ func (s *serviceClient) onTrayReady() { func (s *serviceClient) runSelfCommand(command, arg string) { proc, err := os.Executable() if err != nil { - log.Errorf("show %s failed with error: %v", command, err) + log.Errorf("Error getting executable path: %v", err) return } @@ -793,14 +801,48 @@ func (s *serviceClient) runSelfCommand(command, arg string) { fmt.Sprintf("--daemon-addr=%s", s.addr), ) - out, err := cmd.CombinedOutput() - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 { - log.Errorf("start %s UI: %v, %s", command, err, string(out)) + if out := s.attachOutput(cmd); out != nil { + defer func() { + if err := out.Close(); err != nil { + log.Errorf("Error closing log file %s: %v", s.logFile, err) + } + }() + } + + log.Printf("Running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, s.addr) + + err = cmd.Run() + + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + log.Printf("Command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode()) + } else { + log.Printf("Failed to start/run command '%s %s': %v", command, arg, err) + } return } - if len(out) != 0 { - log.Infof("command %s executed: %s", command, string(out)) + + log.Printf("Command '%s %s' completed successfully.", command, arg) +} + +func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { + if s.logFile == "" { + // attach child's streams to parent's streams + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + return nil } + + out, err := os.OpenFile(s.logFile, os.O_WRONLY|os.O_APPEND, 0) + if err != nil { + log.Errorf("Failed to open log file %s: %v", s.logFile, err) + return nil + } + cmd.Stdout = out + cmd.Stderr = out + return out } func normalizedVersion(version string) string { @@ -813,9 +855,7 @@ func normalizedVersion(version string) string { // onTrayExit is called when the tray icon is closed. func (s *serviceClient) onTrayExit() { - for _, item := range s.mExitNodeItems { - item.cancel() - } + s.cancel() } // getSrvClient connection to the service. @@ -824,7 +864,7 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(s.ctx, timeout) defer cancel() conn, err := grpc.DialContext( diff --git a/client/ui/debug.go b/client/ui/debug.go index e950e6d1e..ab7dba37a 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -3,8 +3,12 @@ package main import ( + "context" "fmt" "path/filepath" + "strconv" + "sync" + "time" "fyne.io/fyne/v2" "fyne.io/fyne/v2/container" @@ -13,18 +17,46 @@ import ( log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" ) +// Initial state for the debug collection +type debugInitialState struct { + wasDown bool + logLevel proto.LogLevel + isLevelTrace bool +} + +// Debug collection parameters +type debugCollectionParams struct { + duration time.Duration + anonymize bool + systemInfo bool + upload bool + uploadURL string + enablePersistence bool +} + +// UI components for progress tracking +type progressUI struct { + statusLabel *widget.Label + progressBar *widget.ProgressBar + uiControls []fyne.Disableable + window fyne.Window +} + func (s *serviceClient) showDebugUI() { w := s.app.NewWindow("NetBird Debug") - w.Resize(fyne.NewSize(600, 400)) + w.SetOnClosed(s.cancel) + + w.Resize(fyne.NewSize(600, 500)) w.SetFixedSize(true) - anonymizeCheck := widget.NewCheck("Anonymize sensitive information (Public IPs, domains, ...)", nil) - systemInfoCheck := widget.NewCheck("Include system information", nil) + anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil) + systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil) systemInfoCheck.SetChecked(true) uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) uploadCheck.SetChecked(true) @@ -34,11 +66,6 @@ func (s *serviceClient) showDebugUI() { uploadURL.SetText(uptypes.DefaultBundleURL) uploadURL.SetPlaceHolder("Enter upload URL") - statusLabel := widget.NewLabel("") - statusLabel.Hide() - - createButton := widget.NewButton("Create Debug Bundle", nil) - uploadURLContainer := container.NewVBox( uploadURLLabel, uploadURL, @@ -52,7 +79,71 @@ func (s *serviceClient) showDebugUI() { } } - createButton.OnTapped = s.getCreateHandler(createButton, statusLabel, uploadCheck, uploadURL, anonymizeCheck, systemInfoCheck, w) + debugModeContainer := container.NewHBox() + runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil) + runForDurationCheck.SetChecked(true) + + forLabel := widget.NewLabel("for") + + durationInput := widget.NewEntry() + durationInput.SetText("1") + minutesLabel := widget.NewLabel("minute") + durationInput.Validator = func(s string) error { + return validateMinute(s, minutesLabel) + } + + noteLabel := widget.NewLabel("Note: NetBird will be brought up and down during collection") + + runForDurationCheck.OnChanged = func(checked bool) { + if checked { + forLabel.Show() + durationInput.Show() + minutesLabel.Show() + noteLabel.Show() + } else { + forLabel.Hide() + durationInput.Hide() + minutesLabel.Hide() + noteLabel.Hide() + } + } + + debugModeContainer.Add(runForDurationCheck) + debugModeContainer.Add(forLabel) + debugModeContainer.Add(durationInput) + debugModeContainer.Add(minutesLabel) + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + + progressBar := widget.NewProgressBar() + progressBar.Hide() + + createButton := widget.NewButton("Create Debug Bundle", nil) + + // UI controls that should be disabled during debug collection + uiControls := []fyne.Disableable{ + anonymizeCheck, + systemInfoCheck, + uploadCheck, + uploadURL, + runForDurationCheck, + durationInput, + createButton, + } + + createButton.OnTapped = s.getCreateHandler( + statusLabel, + progressBar, + uploadCheck, + uploadURL, + anonymizeCheck, + systemInfoCheck, + runForDurationCheck, + durationInput, + uiControls, + w, + ) content := container.NewVBox( widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), @@ -62,7 +153,11 @@ func (s *serviceClient) showDebugUI() { uploadCheck, uploadURLContainer, widget.NewLabel(""), + debugModeContainer, + noteLabel, + widget.NewLabel(""), statusLabel, + progressBar, createButton, ) @@ -72,18 +167,46 @@ func (s *serviceClient) showDebugUI() { w.Show() } +func validateMinute(s string, minutesLabel *widget.Label) error { + if val, err := strconv.Atoi(s); err != nil || val < 1 { + return fmt.Errorf("must be a number ≥ 1") + } + if s == "1" { + minutesLabel.SetText("minute") + } else { + minutesLabel.SetText("minutes") + } + return nil +} + +// disableUIControls disables the provided UI controls +func disableUIControls(controls []fyne.Disableable) { + for _, control := range controls { + control.Disable() + } +} + +// enableUIControls enables the provided UI controls +func enableUIControls(controls []fyne.Disableable) { + for _, control := range controls { + control.Enable() + } +} + func (s *serviceClient) getCreateHandler( - createButton *widget.Button, statusLabel *widget.Label, + progressBar *widget.ProgressBar, uploadCheck *widget.Check, uploadURL *widget.Entry, anonymizeCheck *widget.Check, systemInfoCheck *widget.Check, + runForDurationCheck *widget.Check, + duration *widget.Entry, + uiControls []fyne.Disableable, w fyne.Window, ) func() { return func() { - createButton.Disable() - statusLabel.SetText("Creating debug bundle...") + disableUIControls(uiControls) statusLabel.Show() var url string @@ -91,22 +214,329 @@ func (s *serviceClient) getCreateHandler( url = uploadURL.Text if url == "" { statusLabel.SetText("Error: Upload URL is required when upload is enabled") - createButton.Enable() + enableUIControls(uiControls) return } } - go s.handleDebugCreation(anonymizeCheck.Checked, systemInfoCheck.Checked, uploadCheck.Checked, url, statusLabel, createButton, w) + params := &debugCollectionParams{ + anonymize: anonymizeCheck.Checked, + systemInfo: systemInfoCheck.Checked, + upload: uploadCheck.Checked, + uploadURL: url, + enablePersistence: true, + } + + runForDuration := runForDurationCheck.Checked + if runForDuration { + minutes, err := time.ParseDuration(duration.Text + "m") + if err != nil { + statusLabel.SetText(fmt.Sprintf("Error: Invalid duration: %v", err)) + enableUIControls(uiControls) + return + } + params.duration = minutes + + statusLabel.SetText(fmt.Sprintf("Running in debug mode for %d minutes...", int(minutes.Minutes()))) + progressBar.Show() + progressBar.SetValue(0) + + go s.handleRunForDuration( + statusLabel, + progressBar, + uiControls, + w, + params, + ) + return + } + + statusLabel.SetText("Creating debug bundle...") + go s.handleDebugCreation( + anonymizeCheck.Checked, + systemInfoCheck.Checked, + uploadCheck.Checked, + url, + statusLabel, + uiControls, + w, + ) } } +func (s *serviceClient) handleRunForDuration( + statusLabel *widget.Label, + progressBar *widget.ProgressBar, + uiControls []fyne.Disableable, + w fyne.Window, + params *debugCollectionParams, +) { + progressUI := &progressUI{ + statusLabel: statusLabel, + progressBar: progressBar, + uiControls: uiControls, + window: w, + } + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + handleError(progressUI, fmt.Sprintf("Failed to get client for debug: %v", err)) + return + } + + initialState, err := s.getInitialState(conn) + if err != nil { + handleError(progressUI, err.Error()) + return + } + + statusOutput, err := s.collectDebugData(conn, initialState, params, progressUI) + if err != nil { + handleError(progressUI, err.Error()) + return + } + + if err := s.createDebugBundleFromCollection(conn, params, statusOutput, progressUI); err != nil { + handleError(progressUI, err.Error()) + return + } + + s.restoreServiceState(conn, initialState) + + progressUI.statusLabel.SetText("Bundle created successfully") +} + +// Get initial state of the service +func (s *serviceClient) getInitialState(conn proto.DaemonServiceClient) (*debugInitialState, error) { + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + return nil, fmt.Errorf(" get status: %v", err) + } + + logLevelResp, err := conn.GetLogLevel(s.ctx, &proto.GetLogLevelRequest{}) + if err != nil { + return nil, fmt.Errorf("get log level: %v", err) + } + + wasDown := statusResp.Status != string(internal.StatusConnected) && + statusResp.Status != string(internal.StatusConnecting) + + initialLogLevel := logLevelResp.GetLevel() + initialLevelTrace := initialLogLevel >= proto.LogLevel_TRACE + + return &debugInitialState{ + wasDown: wasDown, + logLevel: initialLogLevel, + isLevelTrace: initialLevelTrace, + }, nil +} + +// Handle progress tracking during collection +func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time.Duration, progress *progressUI) { + progress.progressBar.Show() + progress.progressBar.SetValue(0) + + startTime := time.Now() + endTime := startTime.Add(duration) + wg.Add(1) + + go func() { + defer wg.Done() + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + remaining := time.Until(endTime) + if remaining <= 0 { + remaining = 0 + } + + elapsed := time.Since(startTime) + progressVal := float64(elapsed) / float64(duration) + if progressVal > 1.0 { + progressVal = 1.0 + } + + progress.progressBar.SetValue(progressVal) + progress.statusLabel.SetText(fmt.Sprintf("Running with trace logs... %s remaining", formatDuration(remaining))) + } + } + }() + +} + +func (s *serviceClient) configureServiceForDebug( + conn proto.DaemonServiceClient, + state *debugInitialState, + enablePersistence bool, +) error { + if state.wasDown { + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("bring service up: %v", err) + } + log.Info("Service brought up for debug") + time.Sleep(time.Second * 10) + } + + if !state.isLevelTrace { + if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil { + return fmt.Errorf("set log level to TRACE: %v", err) + } + log.Info("Log level set to TRACE for debug") + } + + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("bring service down: %v", err) + } + time.Sleep(time.Second) + + if enablePersistence { + if _, err := conn.SetNetworkMapPersistence(s.ctx, &proto.SetNetworkMapPersistenceRequest{ + Enabled: true, + }); err != nil { + return fmt.Errorf("enable network map persistence: %v", err) + } + log.Info("Network map persistence enabled for debug") + } + + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("bring service back up: %v", err) + } + time.Sleep(time.Second * 3) + + return nil +} + +func (s *serviceClient) collectDebugData( + conn proto.DaemonServiceClient, + state *debugInitialState, + params *debugCollectionParams, + progress *progressUI, +) (string, error) { + ctx, cancel := context.WithTimeout(s.ctx, params.duration) + defer cancel() + var wg sync.WaitGroup + startProgressTracker(ctx, &wg, params.duration, progress) + + if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { + return "", err + } + + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) + if err != nil { + log.Warnf("Failed to get post-up status: %v", err) + } + + var postUpStatusOutput string + if postUpStatus != nil { + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil) + postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) + } + headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) + statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput) + + wg.Wait() + progress.progressBar.Hide() + progress.statusLabel.SetText("Collecting debug data...") + + preDownStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) + if err != nil { + log.Warnf("Failed to get pre-down status: %v", err) + } + + var preDownStatusOutput string + if preDownStatus != nil { + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil) + preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) + } + headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", + time.Now().Format(time.RFC3339), params.duration) + statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, preDownStatusOutput) + + return statusOutput, nil +} + +// Create the debug bundle with collected data +func (s *serviceClient) createDebugBundleFromCollection( + conn proto.DaemonServiceClient, + params *debugCollectionParams, + statusOutput string, + progress *progressUI, +) error { + progress.statusLabel.SetText("Creating debug bundle with collected logs...") + + request := &proto.DebugBundleRequest{ + Anonymize: params.anonymize, + Status: statusOutput, + SystemInfo: params.systemInfo, + } + + if params.upload { + request.UploadURL = params.uploadURL + } + + resp, err := conn.DebugBundle(s.ctx, request) + if err != nil { + return fmt.Errorf("create debug bundle: %v", err) + } + + // Show appropriate dialog based on upload status + localPath := resp.GetPath() + uploadFailureReason := resp.GetUploadFailureReason() + uploadedKey := resp.GetUploadedKey() + + if params.upload { + if uploadFailureReason != "" { + showUploadFailedDialog(progress.window, localPath, uploadFailureReason) + } else { + showUploadSuccessDialog(progress.window, localPath, uploadedKey) + } + } else { + showBundleCreatedDialog(progress.window, localPath) + } + + enableUIControls(progress.uiControls) + return nil +} + +// Restore service to original state +func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) { + if state.wasDown { + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + log.Errorf("Failed to restore down state: %v", err) + } else { + log.Info("Service state restored to down") + } + } + + if !state.isLevelTrace { + if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil { + log.Errorf("Failed to restore log level: %v", err) + } else { + log.Info("Log level restored to original setting") + } + } +} + +// Handle errors during debug collection +func handleError(progress *progressUI, errMsg string) { + log.Errorf("%s", errMsg) + progress.statusLabel.SetText(errMsg) + progress.progressBar.Hide() + enableUIControls(progress.uiControls) +} + func (s *serviceClient) handleDebugCreation( anonymize bool, systemInfo bool, upload bool, uploadURL string, statusLabel *widget.Label, - createButton *widget.Button, + uiControls []fyne.Disableable, w fyne.Window, ) { log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", @@ -116,7 +546,7 @@ func (s *serviceClient) handleDebugCreation( if err != nil { log.Errorf("Failed to create debug bundle: %v", err) statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) - createButton.Enable() + enableUIControls(uiControls) return } @@ -134,7 +564,7 @@ func (s *serviceClient) handleDebugCreation( showBundleCreatedDialog(w, localPath) } - createButton.Enable() + enableUIControls(uiControls) statusLabel.SetText("Bundle created successfully") } @@ -173,32 +603,47 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return resp, nil } +// formatDuration formats a duration in HH:MM:SS format +func formatDuration(d time.Duration) string { + d = d.Round(time.Second) + h := d / time.Hour + d %= time.Hour + m := d / time.Minute + d %= time.Minute + s := d / time.Second + return fmt.Sprintf("%02d:%02d:%02d", h, m, s) +} + +// createButtonWithAction creates a button with the given label and action +func createButtonWithAction(label string, action func()) *widget.Button { + button := widget.NewButton(label, action) + return button +} + // showUploadFailedDialog displays a dialog when upload fails -func showUploadFailedDialog(parent fyne.Window, localPath, failureReason string) { +func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) { content := container.NewVBox( widget.NewLabel(fmt.Sprintf("Bundle upload failed:\n%s\n\n"+ "A local copy was saved at:\n%s", failureReason, localPath)), ) - customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, parent) + customDialog := dialog.NewCustom("Upload Failed", "Cancel", content, w) buttonBox := container.NewHBox( - widget.NewButton("Open File", func() { + createButtonWithAction("Open file", func() { log.Infof("Attempting to open local file: %s", localPath) if openErr := open.Start(localPath); openErr != nil { log.Errorf("Failed to open local file '%s': %v", localPath, openErr) - dialog.ShowError(fmt.Errorf("Failed to open the local file:\n%s\n\nError: %v", localPath, openErr), parent) + dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w) } - customDialog.Hide() }), - widget.NewButton("Open Folder", func() { + createButtonWithAction("Open folder", func() { folderPath := filepath.Dir(localPath) log.Infof("Attempting to open local folder: %s", folderPath) if openErr := open.Start(folderPath); openErr != nil { log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) - dialog.ShowError(fmt.Errorf("Failed to open the local folder:\n%s\n\nError: %v", folderPath, openErr), parent) + dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w) } - customDialog.Hide() }), ) @@ -207,7 +652,8 @@ func showUploadFailedDialog(parent fyne.Window, localPath, failureReason string) } // showUploadSuccessDialog displays a dialog when upload succeeds -func showUploadSuccessDialog(parent fyne.Window, localPath, uploadedKey string) { +func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) { + log.Infof("Upload key: %s", uploadedKey) keyEntry := widget.NewEntry() keyEntry.SetText(uploadedKey) keyEntry.Disable() @@ -215,62 +661,63 @@ func showUploadSuccessDialog(parent fyne.Window, localPath, uploadedKey string) content := container.NewVBox( widget.NewLabel("Bundle uploaded successfully!"), widget.NewLabel(""), - widget.NewLabel("Upload Key:"), + widget.NewLabel("Upload key:"), keyEntry, widget.NewLabel(""), widget.NewLabel(fmt.Sprintf("Local copy saved at:\n%s", localPath)), ) - customDialog := dialog.NewCustom("Upload Successful", "OK", content, parent) + customDialog := dialog.NewCustom("Upload Successful", "OK", content, w) - buttonBox := container.NewHBox( - widget.NewButton("Copy Key", func() { - parent.Clipboard().SetContent(uploadedKey) - log.Info("Upload key copied to clipboard") - }), - widget.NewButton("Open Local Folder", func() { - folderPath := filepath.Dir(localPath) - log.Infof("Attempting to open local folder: %s", folderPath) - if openErr := open.Start(folderPath); openErr != nil { - log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) - dialog.ShowError(fmt.Errorf("Failed to open the local folder:\n%s\n\nError: %v", folderPath, openErr), parent) - } - }), - ) + copyBtn := createButtonWithAction("Copy key", func() { + w.Clipboard().SetContent(uploadedKey) + log.Info("Upload key copied to clipboard") + }) + buttonBox := createButtonBox(localPath, w, copyBtn) content.Add(buttonBox) customDialog.Show() } // showBundleCreatedDialog displays a dialog when bundle is created without upload -func showBundleCreatedDialog(parent fyne.Window, localPath string) { +func showBundleCreatedDialog(w fyne.Window, localPath string) { content := container.NewVBox( widget.NewLabel(fmt.Sprintf("Bundle created locally at:\n%s\n\n"+ "Administrator privileges may be required to access the file.", localPath)), ) - customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, parent) - - buttonBox := container.NewHBox( - widget.NewButton("Open File", func() { - log.Infof("Attempting to open local file: %s", localPath) - if openErr := open.Start(localPath); openErr != nil { - log.Errorf("Failed to open local file '%s': %v", localPath, openErr) - dialog.ShowError(fmt.Errorf("Failed to open the local file:\n%s\n\nError: %v", localPath, openErr), parent) - } - customDialog.Hide() - }), - widget.NewButton("Open Folder", func() { - folderPath := filepath.Dir(localPath) - log.Infof("Attempting to open local folder: %s", folderPath) - if openErr := open.Start(folderPath); openErr != nil { - log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) - dialog.ShowError(fmt.Errorf("Failed to open the local folder:\n%s\n\nError: %v", folderPath, openErr), parent) - } - customDialog.Hide() - }), - ) + customDialog := dialog.NewCustom("Debug Bundle Created", "Cancel", content, w) + buttonBox := createButtonBox(localPath, w, nil) content.Add(buttonBox) customDialog.Show() } + +func createButtonBox(localPath string, w fyne.Window, elems ...fyne.Widget) *fyne.Container { + box := container.NewHBox() + for _, elem := range elems { + box.Add(elem) + } + + fileBtn := createButtonWithAction("Open file", func() { + log.Infof("Attempting to open local file: %s", localPath) + if openErr := open.Start(localPath); openErr != nil { + log.Errorf("Failed to open local file '%s': %v", localPath, openErr) + dialog.ShowError(fmt.Errorf("open the local file:\n%s\n\nError: %v", localPath, openErr), w) + } + }) + + folderBtn := createButtonWithAction("Open folder", func() { + folderPath := filepath.Dir(localPath) + log.Infof("Attempting to open local folder: %s", folderPath) + if openErr := open.Start(folderPath); openErr != nil { + log.Errorf("Failed to open local folder '%s': %v", folderPath, openErr) + dialog.ShowError(fmt.Errorf("open the local folder:\n%s\n\nError: %v", folderPath, openErr), w) + } + }) + + box.Add(fileBtn) + box.Add(folderBtn) + + return box +} diff --git a/client/ui/network.go b/client/ui/network.go index ddd8d5000..435917f30 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -34,7 +34,8 @@ const ( type filter string func (s *serviceClient) showNetworksUI() { - s.wRoutes = s.app.NewWindow("Networks") + s.wNetworks = s.app.NewWindow("Networks") + s.wNetworks.SetOnClosed(s.cancel) allGrid := container.New(layout.NewGridLayout(3)) go s.updateNetworks(allGrid, allNetworks) @@ -78,8 +79,8 @@ func (s *serviceClient) showNetworksUI() { content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer) - s.wRoutes.SetContent(content) - s.wRoutes.Show() + s.wNetworks.SetContent(content) + s.wNetworks.Show() s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } @@ -148,7 +149,7 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Add(resolvedIPsSelector) } - s.wRoutes.Content().Refresh() + s.wNetworks.Content().Refresh() grid.Refresh() } @@ -305,7 +306,7 @@ func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.Se func (s *serviceClient) showError(err error) { wrappedMessage := wrapText(err.Error(), 50) - dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes) + dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wNetworks) } func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { @@ -316,14 +317,15 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container } }() - s.wRoutes.SetOnClosed(func() { + s.wNetworks.SetOnClosed(func() { ticker.Stop() + s.cancel() }) } func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) - s.wRoutes.Content().Refresh() + s.wNetworks.Content().Refresh() s.updateNetworks(grid, f) } @@ -373,7 +375,7 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { node.Selected, ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(s.ctx) s.mExitNodeItems = append(s.mExitNodeItems, menuHandler{ MenuItem: menuItem, cancel: cancel, From 2abb92b0d4c0bf9aa3a2c45863c0872f31f2908f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 2 May 2025 00:25:46 +0200 Subject: [PATCH 33/45] [management] Get account id with order (#3773) updated log to display account id --- management/server/account.go | 2 +- management/server/store/sql_store.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index ab1ffe8b3..869b13f59 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -721,7 +721,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) if err != nil { return nil, nil, err } - log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData)) + log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), account.Id) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 7d3b288e0..dd39cf77d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -802,7 +802,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { var account types.Account - result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account) + result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account) if result.Error != nil { return "", status.NewGetAccountFromStoreError(result.Error) } From 12f883badfe00ec04ad8bd3d5dc96dca1c786289 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 2 May 2025 00:59:41 +0200 Subject: [PATCH 34/45] [management] Optimize load account (#3774) --- management/server/account.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 869b13f59..aa7cb0019 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -712,7 +712,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) - account, err := am.Store.GetAccount(ctx, accountIDString) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) if err != nil { return nil, nil, err } @@ -721,7 +721,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) if err != nil { return nil, nil, err } - log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), account.Id) + log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { @@ -729,7 +729,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) } matchedUserData := make([]*idp.UserData, 0) - for _, user := range account.Users { + for _, user := range accountUsers { if user.IsServiceUser { continue } From 055df9854c8af522919f91eed59c8f12b730cbe5 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 4 May 2025 20:58:04 +0200 Subject: [PATCH 35/45] [management] add gorm tag for primary key for the networks objects (#3758) --- management/server/migration/migration.go | 21 ++++++++++++++++++ management/server/migration/migration_test.go | 22 +++++++++++++++++++ .../networks/resources/types/resource.go | 2 +- .../server/networks/routers/types/router.go | 2 +- management/server/networks/types/network.go | 2 +- management/server/store/sql_store.go | 4 ++++ management/server/store/sql_store_test.go | 6 ++--- management/server/store/store.go | 9 ++++++++ management/server/types/group.go | 2 +- 9 files changed, 63 insertions(+), 7 deletions(-) diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index d7abbad47..c8a852e0a 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -352,3 +352,24 @@ func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string, log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName) return nil } + +func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error { + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model) + return nil + } + + if !db.Migrator().HasIndex(&model, indexName) { + log.WithContext(ctx).Debugf("index %s does not exist in table %T, no migration needed", indexName, model) + return nil + } + + if err := db.Migrator().DropIndex(&model, indexName); err != nil { + return fmt.Errorf("failed to drop index %s: %w", indexName, err) + } + + log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model) + return nil +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index e907d6853..94377930a 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -227,3 +227,25 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing. assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") } + +func TestDropIndex(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&types.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&types.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + exist := db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") + assert.True(t, exist, "Should have the index") + + err = migration.DropIndex[types.SetupKey](context.Background(), db, "idx_setup_keys_account_id") + require.NoError(t, err, "Migration should not fail to remove index") + + exist = db.Migrator().HasIndex(&types.SetupKey{}, "idx_setup_keys_account_id") + assert.False(t, exist, "Should not have the index") +} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index ecac0a724..04c63608d 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -30,7 +30,7 @@ func (p NetworkResourceType) String() string { } type NetworkResource struct { - ID string `gorm:"index"` + ID string `gorm:"primaryKey"` NetworkID string `gorm:"index"` AccountID string `gorm:"index"` Name string diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 5158ebb12..71465868f 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -10,7 +10,7 @@ import ( ) type NetworkRouter struct { - ID string `gorm:"index"` + ID string `gorm:"primaryKey"` NetworkID string `gorm:"index"` AccountID string `gorm:"index"` Peer string diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go index a4ba7b821..d1c7f2b33 100644 --- a/management/server/networks/types/network.go +++ b/management/server/networks/types/network.go @@ -7,7 +7,7 @@ import ( ) type Network struct { - ID string `gorm:"index"` + ID string `gorm:"primaryKey"` AccountID string `gorm:"index"` Name string Description string diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index dd39cf77d..d0adad6ee 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -82,6 +82,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } conns = 1 + _, err = sql.Exec("PRAGMA foreign_keys = ON") + if err != nil { + return nil, fmt.Errorf("failed to set foreign keys for sqlite: %w", err) + } } sql.SetMaxOpenConns(conns) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 8bd8ce098..8e99b34e1 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -60,10 +60,10 @@ func Test_NewStore(t *testing.T) { runTestForAllEngines(t, "", func(t *testing.T, store Store) { if store == nil { - t.Errorf("expected to create a new Store") + t.Fatalf("expected to create a new Store") } if len(store.GetAllAccounts(context.Background())) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") + t.Fatalf("expected to create a new empty Accounts map when creating a new FileStore") } }) } @@ -1115,7 +1115,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { group := &types.Group{ ID: "group-id", - AccountID: "account-id", + AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", Name: "group-name", Issued: "api", Peers: nil, diff --git a/management/server/store/store.go b/management/server/store/store.go index ca332a493..6da623956 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -315,6 +315,15 @@ func getMigrations(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true) }, + func(db *gorm.DB) error { + return migration.DropIndex[networkTypes.Network](ctx, db, "idx_networks_id") + }, + func(db *gorm.DB) error { + return migration.DropIndex[resourceTypes.NetworkResource](ctx, db, "idx_network_resources_id") + }, + func(db *gorm.DB) error { + return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id") + }, } } diff --git a/management/server/types/group.go b/management/server/types/group.go index 00a28fa77..1b321387c 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -14,7 +14,7 @@ const ( // Group of the peers for ACL type Group struct { // ID of the group - ID string + ID string `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` From ffdd115ded596f53085a37afbda9bcba7ea353a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alin=20Tr=C4=83istaru?= Date: Mon, 5 May 2025 12:20:54 +0200 Subject: [PATCH 36/45] [client] set TLS ServerName for hostname-based QUIC connections (#3673) * fix: set TLS ServerName for hostname-based QUIC connections When connecting to a relay server by hostname, certificates are validated against the IP address instead of the hostname. This change sets ServerName in the TLS config when connecting via hostname, ensuring proper certificate validation. * use default port if port is missing in URL string --- relay/client/dialer/quic/quic.go | 37 +++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/relay/client/dialer/quic/quic.go b/relay/client/dialer/quic/quic.go index 7fd486f87..3fd48fb19 100644 --- a/relay/client/dialer/quic/quic.go +++ b/relay/client/dialer/quic/quic.go @@ -28,6 +28,16 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } + // Get the base TLS config + tlsClientConfig := quictls.ClientQUICTLSConfig() + + // Set ServerName to hostname if not an IP address + host, _, splitErr := net.SplitHostPort(quicURL) + if splitErr == nil && net.ParseIP(host) == nil { + // It's a hostname, not an IP - modify directly + tlsClientConfig.ServerName = host + } + quicConfig := &quic.Config{ KeepAlivePeriod: 30 * time.Second, MaxIdleTimeout: 4 * time.Minute, @@ -47,7 +57,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } - session, err := quic.Dial(ctx, udpConn, udpAddr, quictls.ClientQUICTLSConfig(), quicConfig) + session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig) if err != nil { if errors.Is(err, context.Canceled) { return nil, err @@ -61,12 +71,29 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { } func prepareURL(address string) (string, error) { - if !strings.HasPrefix(address, "rel://") && !strings.HasPrefix(address, "rels://") { + var host string + var defaultPort string + + switch { + case strings.HasPrefix(address, "rels://"): + host = address[7:] + defaultPort = "443" + case strings.HasPrefix(address, "rel://"): + host = address[6:] + defaultPort = "80" + default: return "", fmt.Errorf("unsupported scheme: %s", address) } - if strings.HasPrefix(address, "rels://") { - return address[7:], nil + finalHost, finalPort, err := net.SplitHostPort(host) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + return host + ":" + defaultPort, nil + } + + // return any other split error as is + return "", err } - return address[6:], nil + + return finalHost + ":" + finalPort, nil } From 9762b39f29e63033bfbd8a5b68aa320db1ed4584 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 5 May 2025 14:29:05 +0200 Subject: [PATCH 37/45] [client] Fix stale local records (#3776) --- client/internal/dns/handler_chain_test.go | 40 +- client/internal/dns/local.go | 130 ------ client/internal/dns/local/local.go | 149 +++++++ client/internal/dns/local/local_test.go | 472 ++++++++++++++++++++++ client/internal/dns/local_test.go | 88 ---- client/internal/dns/mock_test.go | 26 -- client/internal/dns/server.go | 85 ++-- client/internal/dns/server_test.go | 160 ++++---- client/internal/dns/test/mock.go | 26 ++ client/internal/dns/types/types.go | 3 + client/internal/dns/upstream.go | 11 +- client/internal/dns/upstream_test.go | 6 +- dns/dns.go | 6 +- 13 files changed, 786 insertions(+), 416 deletions(-) delete mode 100644 client/internal/dns/local.go create mode 100644 client/internal/dns/local/local.go create mode 100644 client/internal/dns/local/local_test.go delete mode 100644 client/internal/dns/local_test.go delete mode 100644 client/internal/dns/mock_test.go create mode 100644 client/internal/dns/test/mock.go create mode 100644 client/internal/dns/types/types.go diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 4c910a95f..5f03e0758 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,7 +1,6 @@ package dns_test import ( - "net" "testing" "github.com/miekg/dns" @@ -9,6 +8,7 @@ import ( "github.com/stretchr/testify/mock" nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dns/test" ) // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order @@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { r.SetQuestion("example.com.", dns.TypeA) // Create test writer - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup expectations - only highest priority handler should be called dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() @@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) @@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { // Create and execute request r := new(dns.Msg) r.SetQuestion(tt.queryDomain, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) // Verify expectations @@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { }).Once() // Execute - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) // Verify all handlers were called in order @@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { handler3.AssertExpectations(t) } -// mockResponseWriter implements dns.ResponseWriter for testing -type mockResponseWriter struct { - mock.Mock -} - -func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } -func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } -func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } -func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } -func (m *mockResponseWriter) Close() error { return nil } -func (m *mockResponseWriter) TsigStatus() error { return nil } -func (m *mockResponseWriter) TsigTimersOnly(bool) {} -func (m *mockResponseWriter) Hijack() {} - func TestHandlerChain_PriorityDeregistration(t *testing.T) { tests := []struct { name string @@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { // Create test request r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup expectations for priority, handler := range handlers { @@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) // Test 1: Initial state - w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Highest priority handler (routeHandler) should be called routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet @@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 2: Remove highest priority handler chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) - w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now middle priority handler (matchHandler) should be called matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet @@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 3: Remove middle priority handler chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) - w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() @@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) - w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain for _, m := range mocks { @@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { // Execute request r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - chain.ServeDNS(&mockResponseWriter{}, r) + chain.ServeDNS(&test.MockResponseWriter{}, r) // Verify each handler was called exactly as expected for _, h := range tt.addHandlers { @@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tt.query, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup handler expectations for pattern, handler := range handlers { @@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { handler := &nbdns.MockHandler{} r := new(dns.Msg) r.SetQuestion(tt.queryPattern, dns.TypeA) - w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // First verify no handler is called before adding any chain.ServeDNS(w, r) diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go deleted file mode 100644 index 76e18e3ce..000000000 --- a/client/internal/dns/local.go +++ /dev/null @@ -1,130 +0,0 @@ -package dns - -import ( - "fmt" - "strings" - "sync" - - "github.com/miekg/dns" - log "github.com/sirupsen/logrus" - - nbdns "github.com/netbirdio/netbird/dns" -) - -type registrationMap map[string]struct{} - -type localResolver struct { - registeredMap registrationMap - records sync.Map // key: string (domain_class_type), value: []dns.RR -} - -func (d *localResolver) MatchSubdomains() bool { - return true -} - -func (d *localResolver) stop() { -} - -// String returns a string representation of the local resolver -func (d *localResolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) -} - -// ID returns the unique handler ID -func (d *localResolver) id() handlerID { - return "local-resolver" -} - -// ServeDNS handles a DNS request -func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) > 0 { - log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - } - - replyMessage := &dns.Msg{} - replyMessage.SetReply(r) - replyMessage.RecursionAvailable = true - - // lookup all records matching the question - records := d.lookupRecords(r) - if len(records) > 0 { - replyMessage.Rcode = dns.RcodeSuccess - replyMessage.Answer = append(replyMessage.Answer, records...) - } else { - replyMessage.Rcode = dns.RcodeNameError - } - - err := w.WriteMsg(replyMessage) - if err != nil { - log.Debugf("got an error while writing the local resolver response, error: %v", err) - } -} - -// lookupRecords fetches *all* DNS records matching the first question in r. -func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR { - if len(r.Question) == 0 { - return nil - } - question := r.Question[0] - question.Name = strings.ToLower(question.Name) - key := buildRecordKey(question.Name, question.Qclass, question.Qtype) - - value, found := d.records.Load(key) - if !found { - // alternatively check if we have a cname - if question.Qtype != dns.TypeCNAME { - r.Question[0].Qtype = dns.TypeCNAME - return d.lookupRecords(r) - } - - return nil - } - - records, ok := value.([]dns.RR) - if !ok { - log.Errorf("failed to cast records to []dns.RR, records: %v", value) - return nil - } - - // if there's more than one record, rotate them (round-robin) - if len(records) > 1 { - first := records[0] - records = append(records[1:], first) - d.records.Store(key, records) - } - - return records -} - -// registerRecord stores a new record by appending it to any existing list -func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) { - rr, err := dns.NewRR(record.String()) - if err != nil { - return "", fmt.Errorf("register record: %w", err) - } - - rr.Header().Rdlength = record.Len() - header := rr.Header() - key := buildRecordKey(header.Name, header.Class, header.Rrtype) - - // load any existing slice of records, then append - existing, _ := d.records.LoadOrStore(key, []dns.RR{}) - records := existing.([]dns.RR) - records = append(records, rr) - - // store updated slice - d.records.Store(key, records) - return key, nil -} - -// deleteRecord removes *all* records under the recordKey. -func (d *localResolver) deleteRecord(recordKey string) { - d.records.Delete(dns.Fqdn(recordKey)) -} - -// buildRecordKey consistently generates a key: name_class_type -func buildRecordKey(name string, class, qType uint16) string { - return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType) -} - -func (d *localResolver) probeAvailability() {} diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go new file mode 100644 index 000000000..de3d8514b --- /dev/null +++ b/client/internal/dns/local/local.go @@ -0,0 +1,149 @@ +package local + +import ( + "fmt" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/dns/types" + nbdns "github.com/netbirdio/netbird/dns" +) + +type Resolver struct { + mu sync.RWMutex + records map[dns.Question][]dns.RR +} + +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +func (d *Resolver) MatchSubdomains() bool { + return true +} + +// String returns a string representation of the local resolver +func (d *Resolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.records)) +} + +func (d *Resolver) Stop() {} + +// ID returns the unique handler ID +func (d *Resolver) ID() types.HandlerID { + return "local-resolver" +} + +func (d *Resolver) ProbeAvailability() {} + +// ServeDNS handles a DNS request +func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + log.Debugf("received local resolver request with no question") + return + } + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass) + + replyMessage := &dns.Msg{} + replyMessage.SetReply(r) + replyMessage.RecursionAvailable = true + + // lookup all records matching the question + records := d.lookupRecords(question) + if len(records) > 0 { + replyMessage.Rcode = dns.RcodeSuccess + replyMessage.Answer = append(replyMessage.Answer, records...) + } else { + // TODO: return success if we have a different record type for the same name, relevant for search domains + replyMessage.Rcode = dns.RcodeNameError + } + + if err := w.WriteMsg(replyMessage); err != nil { + log.Warnf("failed to write the local resolver response: %v", err) + } +} + +// lookupRecords fetches *all* DNS records matching the first question in r. +func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { + d.mu.RLock() + records, found := d.records[question] + + if !found { + d.mu.RUnlock() + // alternatively check if we have a cname + if question.Qtype != dns.TypeCNAME { + question.Qtype = dns.TypeCNAME + return d.lookupRecords(question) + } + return nil + } + + recordsCopy := slices.Clone(records) + d.mu.RUnlock() + + // if there's more than one record, rotate them (round-robin) + if len(recordsCopy) > 1 { + d.mu.Lock() + records = d.records[question] + if len(records) > 1 { + first := records[0] + records = append(records[1:], first) + d.records[question] = records + } + d.mu.Unlock() + } + + return recordsCopy +} + +func (d *Resolver) Update(update []nbdns.SimpleRecord) { + d.mu.Lock() + defer d.mu.Unlock() + + maps.Clear(d.records) + + for _, rec := range update { + if err := d.registerRecord(rec); err != nil { + log.Warnf("failed to register the record (%s): %v", rec, err) + continue + } + } +} + +// RegisterRecord stores a new record by appending it to any existing list +func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error { + d.mu.Lock() + defer d.mu.Unlock() + + return d.registerRecord(record) +} + +// registerRecord performs the registration with the lock already held +func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error { + rr, err := dns.NewRR(record.String()) + if err != nil { + return fmt.Errorf("register record: %w", err) + } + + rr.Header().Rdlength = record.Len() + header := rr.Header() + q := dns.Question{ + Name: strings.ToLower(dns.Fqdn(header.Name)), + Qtype: header.Rrtype, + Qclass: header.Class, + } + + d.records[q] = append(d.records[q], rr) + + return nil +} diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go new file mode 100644 index 000000000..1d38191e7 --- /dev/null +++ b/client/internal/dns/local/local_test.go @@ -0,0 +1,472 @@ +package local + +import ( + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + nbdns "github.com/netbirdio/netbird/dns" +) + +func TestLocalResolver_ServeDNS(t *testing.T) { + recordA := nbdns.SimpleRecord{ + Name: "peera.netbird.cloud.", + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "1.2.3.4", + } + + recordCNAME := nbdns.SimpleRecord{ + Name: "peerb.netbird.cloud.", + Type: 5, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "www.netbird.io", + } + + testCases := []struct { + name string + inputRecord nbdns.SimpleRecord + inputMSG *dns.Msg + responseShouldBeNil bool + }{ + { + name: "Should Resolve A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), + }, + { + name: "Should Resolve CNAME Record", + inputRecord: recordCNAME, + inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), + }, + { + name: "Should Not Write When Not Found A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), + responseShouldBeNil: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resolver := NewResolver() + _ = resolver.RegisterRecord(testCase.inputRecord) + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, testCase.inputMSG) + + if responseMSG == nil || len(responseMSG.Answer) == 0 { + if testCase.responseShouldBeNil { + return + } + t.Fatalf("should write a response message") + } + + answerString := responseMSG.Answer[0].String() + if !strings.Contains(answerString, testCase.inputRecord.Name) { + t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) + } + if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { + t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) + } + if !strings.Contains(answerString, testCase.inputRecord.RData) { + t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) + } + }) + } +} + +// TestLocalResolver_Update_StaleRecord verifies that updating +// a record correctly replaces the old one, preventing stale entries. +func TestLocalResolver_Update_StaleRecord(t *testing.T) { + recordName := "host.example.com." + recordType := dns.TypeA + recordClass := dns.ClassINET + + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2", + } + + recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType} + + resolver := NewResolver() + + update1 := []nbdns.SimpleRecord{record1} + update2 := []nbdns.SimpleRecord{record2} + + // Apply first update + resolver.Update(update1) + + // Verify first update + resolver.mu.RLock() + rrSlice1, found1 := resolver.records[recordKey] + resolver.mu.RUnlock() + + require.True(t, found1, "Record key %s not found after first update", recordKey) + require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update") + assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData) + + // Apply second update + resolver.Update(update2) + + // Verify second update + resolver.mu.RLock() + rrSlice2, found2 := resolver.records[recordKey] + resolver.mu.RUnlock() + + require.True(t, found2, "Record key %s not found after second update", recordKey) + require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key") + assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData) + assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData) +} + +// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records +// with the same question are stored properly +func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) { + resolver := NewResolver() + + recordName := "multi.example.com." + recordType := dns.TypeA + + // Create two records with the same name and type but different IPs + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2", + } + + update := []nbdns.SimpleRecord{record1, record2} + + // Apply update with both records + resolver.Update(update) + + // Create question that matches both records + question := dns.Question{ + Name: recordName, + Qtype: recordType, + Qclass: dns.ClassINET, + } + + // Verify both records are stored + resolver.mu.RLock() + records, found := resolver.records[question] + resolver.mu.RUnlock() + + require.True(t, found, "Records for question %v not found", question) + require.Len(t, records, 2, "Should have exactly 2 records for the same question") + + // Verify both record data values are present + recordStrings := []string{records[0].String(), records[1].String()} + assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present") + assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present") +} + +// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion +func TestLocalResolver_RecordRotation(t *testing.T) { + resolver := NewResolver() + + recordName := "rotation.example.com." + recordType := dns.TypeA + + // Create three records with the same name and type but different IPs + record1 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1", + } + record2 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2", + } + record3 := nbdns.SimpleRecord{ + Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3", + } + + update := []nbdns.SimpleRecord{record1, record2, record3} + + // Apply update with all three records + resolver.Update(update) + + msg := new(dns.Msg).SetQuestion(recordName, recordType) + + // First lookup - should return the records in original order + var responses [3]*dns.Msg + + // Perform three lookups to verify rotation + for i := 0; i < 3; i++ { + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responses[i] = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, msg) + } + + // Verify all three responses contain answers + for i, resp := range responses { + require.NotNil(t, resp, "Response %d should not be nil", i) + require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i) + } + + // Verify the first record in each response is different due to rotation + firstRecordIPs := []string{ + responses[0].Answer[0].String(), + responses[1].Answer[0].String(), + responses[2].Answer[0].String(), + } + + // Each record should be different (rotated) + assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation") + assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation") + assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation") + + // After three rotations, we should have cycled through all records + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData) + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData) + assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData) +} + +// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive +func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) { + resolver := NewResolver() + + // Create record with lowercase name + lowerCaseRecord := nbdns.SimpleRecord{ + Name: "lower.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "10.10.10.10", + } + + // Create record with mixed case name + mixedCaseRecord := nbdns.SimpleRecord{ + Name: "MiXeD.ExAmPlE.CoM.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "20.20.20.20", + } + + // Update resolver with the records + resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}) + + testCases := []struct { + name string + queryName string + expectedRData string + shouldResolve bool + }{ + { + name: "Query lowercase with lowercase record", + queryName: "lower.example.com.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query uppercase with lowercase record", + queryName: "LOWER.EXAMPLE.COM.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query mixed case with lowercase record", + queryName: "LoWeR.eXaMpLe.CoM.", + expectedRData: "10.10.10.10", + shouldResolve: true, + }, + { + name: "Query lowercase with mixed case record", + queryName: "mixed.example.com.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query uppercase with mixed case record", + queryName: "MIXED.EXAMPLE.COM.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query with different casing pattern", + queryName: "mIxEd.ExaMpLe.cOm.", + expectedRData: "20.20.20.20", + shouldResolve: true, + }, + { + name: "Query non-existent domain", + queryName: "nonexistent.example.com.", + shouldResolve: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var responseMSG *dns.Msg + + // Create DNS query with the test case name + msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA) + + // Create mock response writer to capture the response + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + // Perform DNS query + resolver.ServeDNS(responseWriter, msg) + + // Check if we expect a successful resolution + if !tc.shouldResolve { + if responseMSG == nil || len(responseMSG.Answer) == 0 { + // Expected no answer, test passes + return + } + t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer) + } + + // Verify we got a response + require.NotNil(t, responseMSG, "Should have received a response message") + require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer") + + // Verify the response contains the expected data + answerString := responseMSG.Answer[0].String() + assert.Contains(t, answerString, tc.expectedRData, + "Answer should contain the expected IP address %s, got: %s", + tc.expectedRData, answerString) + }) + } +} + +// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back +// to checking for CNAME records when the requested record type isn't found +func TestLocalResolver_CNAMEFallback(t *testing.T) { + resolver := NewResolver() + + // Create a CNAME record (but no A record for this name) + cnameRecord := nbdns.SimpleRecord{ + Name: "alias.example.com.", + Type: int(dns.TypeCNAME), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "target.example.com.", + } + + // Create an A record for the CNAME target + targetRecord := nbdns.SimpleRecord{ + Name: "target.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.100.100", + } + + // Update resolver with both records + resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}) + + testCases := []struct { + name string + queryName string + queryType uint16 + expectedType string + expectedRData string + shouldResolve bool + }{ + { + name: "Directly query CNAME record", + queryName: "alias.example.com.", + queryType: dns.TypeCNAME, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query A record but get CNAME fallback", + queryName: "alias.example.com.", + queryType: dns.TypeA, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query AAAA record but get CNAME fallback", + queryName: "alias.example.com.", + queryType: dns.TypeAAAA, + expectedType: "CNAME", + expectedRData: "target.example.com.", + shouldResolve: true, + }, + { + name: "Query direct A record", + queryName: "target.example.com.", + queryType: dns.TypeA, + expectedType: "A", + expectedRData: "192.168.100.100", + shouldResolve: true, + }, + { + name: "Query non-existent name", + queryName: "nonexistent.example.com.", + queryType: dns.TypeA, + shouldResolve: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var responseMSG *dns.Msg + + // Create DNS query with the test case parameters + msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType) + + // Create mock response writer to capture the response + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + // Perform DNS query + resolver.ServeDNS(responseWriter, msg) + + // Check if we expect a successful resolution + if !tc.shouldResolve { + if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess { + // Expected no resolution, test passes + return + } + t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer) + } + + // Verify we got a successful response + require.NotNil(t, responseMSG, "Should have received a response message") + require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code") + require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer") + + // Verify the response contains the expected data + answerString := responseMSG.Answer[0].String() + assert.Contains(t, answerString, tc.expectedType, + "Answer should be of type %s, got: %s", tc.expectedType, answerString) + assert.Contains(t, answerString, tc.expectedRData, + "Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString) + }) + } +} diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go deleted file mode 100644 index 0a42b321a..000000000 --- a/client/internal/dns/local_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package dns - -import ( - "strings" - "testing" - - "github.com/miekg/dns" - - nbdns "github.com/netbirdio/netbird/dns" -) - -func TestLocalResolver_ServeDNS(t *testing.T) { - recordA := nbdns.SimpleRecord{ - Name: "peera.netbird.cloud.", - Type: 1, - Class: nbdns.DefaultClass, - TTL: 300, - RData: "1.2.3.4", - } - - recordCNAME := nbdns.SimpleRecord{ - Name: "peerb.netbird.cloud.", - Type: 5, - Class: nbdns.DefaultClass, - TTL: 300, - RData: "www.netbird.io", - } - - testCases := []struct { - name string - inputRecord nbdns.SimpleRecord - inputMSG *dns.Msg - responseShouldBeNil bool - }{ - { - name: "Should Resolve A Record", - inputRecord: recordA, - inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), - }, - { - name: "Should Resolve CNAME Record", - inputRecord: recordCNAME, - inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), - }, - { - name: "Should Not Write When Not Found A Record", - inputRecord: recordA, - inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), - responseShouldBeNil: true, - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - resolver := &localResolver{ - registeredMap: make(registrationMap), - } - _, _ = resolver.registerRecord(testCase.inputRecord) - var responseMSG *dns.Msg - responseWriter := &mockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - responseMSG = m - return nil - }, - } - - resolver.ServeDNS(responseWriter, testCase.inputMSG) - - if responseMSG == nil || len(responseMSG.Answer) == 0 { - if testCase.responseShouldBeNil { - return - } - t.Fatalf("should write a response message") - } - - answerString := responseMSG.Answer[0].String() - if !strings.Contains(answerString, testCase.inputRecord.Name) { - t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) - } - if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { - t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) - } - if !strings.Contains(answerString, testCase.inputRecord.RData) { - t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) - } - }) - } -} diff --git a/client/internal/dns/mock_test.go b/client/internal/dns/mock_test.go deleted file mode 100644 index d52ae24da..000000000 --- a/client/internal/dns/mock_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package dns - -import ( - "net" - - "github.com/miekg/dns" -) - -type mockResponseWriter struct { - WriteMsgFunc func(m *dns.Msg) error -} - -func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error { - if rw.WriteMsgFunc != nil { - return rw.WriteMsgFunc(m) - } - return nil -} - -func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil } -func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil } -func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } -func (rw *mockResponseWriter) Close() error { return nil } -func (rw *mockResponseWriter) TsigStatus() error { return nil } -func (rw *mockResponseWriter) TsigTimersOnly(bool) {} -func (rw *mockResponseWriter) Hijack() {} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 65b90e5f0..3f49c23fd 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -15,6 +15,8 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -46,8 +48,6 @@ type Server interface { ProbeAvailability() } -type handlerID string - type nsGroupsByDomain struct { domain string groups []*nbdns.NameServerGroup @@ -61,7 +61,7 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap - localResolver *localResolver + localResolver *local.Resolver wgInterface WGIface hostManager hostManager updateSerial uint64 @@ -84,9 +84,9 @@ type DefaultServer struct { type handlerWithStop interface { dns.Handler - stop() - probeAvailability() - id() handlerID + Stop() + ProbeAvailability() + ID() types.HandlerID } type handlerWrapper struct { @@ -95,7 +95,7 @@ type handlerWrapper struct { priority int } -type registeredHandlerMap map[handlerID]handlerWrapper +type registeredHandlerMap map[types.HandlerID]handlerWrapper // NewDefaultServer returns a new dns server func NewDefaultServer( @@ -171,16 +171,14 @@ func newDefaultServer( handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), wgInterface: wgInterface, statusRecorder: statusRecorder, stateManager: stateManager, @@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() { wg.Add(1) go func(mux handlerWithStop) { defer wg.Done() - mux.probeAvailability() + mux.ProbeAvailability() }(mux.handler) } wg.Wait() @@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.service.Stop() } - localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { return fmt.Errorf("local handler updater: %w", err) } @@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.updateMux(muxUpdates) // register local records - s.updateLocalResolver(localRecordsByDomain) + s.localResolver.Update(localRecords) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) @@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) { ) } -func (s *DefaultServer) buildLocalHandlerUpdate( - customZones []nbdns.CustomZone, -) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) { +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { var muxUpdates []handlerWrapper - localRecords := make(map[string][]nbdns.SimpleRecord) + var localRecords []nbdns.SimpleRecord for _, customZone := range customZones { if len(customZone.Records) == 0 { @@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate( priority: PriorityMatchDomain, }) - // group all records under this domain for _, record := range customZone.Records { - var class uint16 = dns.ClassINET if record.Class != nbdns.DefaultClass { log.Warnf("received an invalid class type: %s", record.Class) continue } - - key := buildRecordKey(record.Name, class, uint16(record.Type)) - - localRecords[key] = append(localRecords[key], record) + // zone records contain the fqdn, so we can just flatten them + localRecords = append(localRecords, record) } } @@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai } if len(handler.upstreamServers) == 0 { - handler.stop() + handler.Stop() log.Errorf("received a nameserver group with an invalid nameserver list") continue } @@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { // this will introduce a short period of time when the server is not able to handle DNS requests for _, existing := range s.dnsMuxMap { s.deregisterHandler([]string{existing.domain}, existing.priority) - existing.handler.stop() + existing.handler.Stop() } muxUpdateMap := make(registeredHandlerMap) @@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { containsRootUpdate = true } s.registerHandler([]string{update.domain}, update.handler, update.priority) - muxUpdateMap[update.handler.id()] = update + muxUpdateMap[update.handler.ID()] = update } // If there's no root update and we had a root handler, restore it @@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { s.dnsMuxMap = muxUpdateMap } -func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) { - // remove old records that are no longer present - for key := range s.localResolver.registeredMap { - _, found := update[key] - if !found { - s.localResolver.deleteRecord(key) - } - } - - updatedMap := make(registrationMap) - for _, recs := range update { - for _, rec := range recs { - // convert the record to a dns.RR and register - key, err := s.localResolver.registerRecord(rec) - if err != nil { - log.Warnf("got an error while registering the record (%s), error: %v", - rec.String(), err) - continue - } - - updatedMap[key] = struct{}{} - } - } - - s.localResolver.registeredMap = updatedMap -} - func getNSHostPort(ns nbdns.NameServer) string { return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index ed69b0e93..1c7c9b117 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -23,6 +23,9 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe } func TestUpdateDNSServer(t *testing.T) { + nameServers := []nbdns.NameServer{ { IP: netip.MustParseAddr("8.8.8.8"), @@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) { }, } - dummyHandler := &localResolver{} + dummyHandler := local.NewResolver() testCases := []struct { name string initUpstreamMap registeredHandlerMap - initLocalMap registrationMap + initLocalRecords []nbdns.SimpleRecord initSerial uint64 inputSerial uint64 inputUpdate nbdns.Config shouldFail bool expectedUpstreamMap registeredHandlerMap - expectedLocalMap registrationMap + expectedLocalQs []dns.Question }{ { name: "Initial Config Should Succeed", - initLocalMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, @@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registeredHandlerMap{ - generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, priority: PriorityMatchDomain, }, - dummyHandler.id(): handlerWrapper{ + dummyHandler.ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, }, - generateDummyHandler(".", nameServers).id(): handlerWrapper{ + generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: nbdns.RootZone, handler: dummyHandler, priority: PriorityDefault, }, }, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, }, { - name: "New Config Should Succeed", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "New Config Should Succeed", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ - domain: buildRecordKey(zoneRecords[0].Name, 1, 1), + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ + domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, }, @@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registeredHandlerMap{ - generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{ + generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, priority: PriorityMatchDomain, @@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) { priority: PriorityMatchDomain, }, }, - expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, }, { - name: "Smaller Config Serial Should Be Skipped", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 2, - inputSerial: 1, - shouldFail: true, + name: "Smaller Config Serial Should Be Skipped", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 2, + inputSerial: 1, + shouldFail: true, }, { - name: "Empty NS Group Domain Or Not Primary Element Should Fail", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Empty NS Group Domain Or Not Primary Element Should Fail", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid NS Group Nameservers list Should Fail", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Invalid NS Group Nameservers list Should Fail", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid Custom Zone Records list Should Skip", - initLocalMap: make(registrationMap), - initUpstreamMap: make(registeredHandlerMap), - initSerial: 0, - inputSerial: 1, + name: "Invalid Custom Zone Records list Should Skip", + initLocalRecords: []nbdns.SimpleRecord{}, + initUpstreamMap: make(registeredHandlerMap), + initSerial: 0, + inputSerial: 1, inputUpdate: nbdns.Config{ ServiceEnable: true, CustomZones: []nbdns.CustomZone{ @@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{ + expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: ".", handler: dummyHandler, priority: PriorityDefault, }}, }, { - name: "Empty Config Should Succeed and Clean Maps", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "Empty Config Should Succeed and Clean Maps", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, priority: PriorityMatchDomain, @@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) { inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: true}, expectedUpstreamMap: make(registeredHandlerMap), - expectedLocalMap: make(registrationMap), + expectedLocalQs: []dns.Question{}, }, { - name: "Disabled Service Should clean map", - initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + name: "Disabled Service Should clean map", + initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{ + generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, priority: PriorityMatchDomain, @@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) { inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: false}, expectedUpstreamMap: make(registeredHandlerMap), - expectedLocalMap: make(registrationMap), + expectedLocalQs: []dns.Question{}, }, } @@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) { }() dnsServer.dnsMuxMap = testCase.initUpstreamMap - dnsServer.localResolver.registeredMap = testCase.initLocalMap + dnsServer.localResolver.Update(testCase.initLocalRecords) dnsServer.updateSerial = testCase.initSerial err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) { } } - if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) { - t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap)) + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + for _, q := range testCase.expectedLocalQs { + dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{ + Question: []dns.Question{q}, + }) } - for key := range testCase.expectedLocalMap { - _, found := dnsServer.localResolver.registeredMap[key] - if !found { - t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap) - } + if len(testCase.expectedLocalQs) > 0 { + assert.NotNil(t, responseMSG, "response message should not be nil") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success") + assert.NotEmpty(t, responseMSG.Answer, "response message should have answers") } }) } @@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { dnsServer.dnsMuxMap = registeredHandlerMap{ "id1": handlerWrapper{ domain: zoneRecords[0].Name, - handler: &localResolver{}, + handler: &local.Resolver{}, priority: PriorityMatchDomain, }, } - dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}} + //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} + dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}) dnsServer.updateSerial = 0 nameServers := []nbdns.NameServer{ @@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) { } time.Sleep(100 * time.Millisecond) defer dnsServer.Stop() - _, err = dnsServer.localResolver.registerRecord(zoneRecords[0]) + err = dnsServer.localResolver.RegisterRecord(zoneRecords[0]) if err != nil { t.Error(err) } @@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, - handlerChain: NewHandlerChain(), - hostManager: hostManager, + ctx: context.Background(), + service: NewServiceViaMemory(&mocWGIface{}), + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { t.Run(tc.name, func(t *testing.T) { r := new(dns.Msg) r.SetQuestion(tc.query, dns.TypeA) - w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} if mh, ok := tc.expectedHandler.(*MockHandler); ok { mh.On("ServeDNS", mock.Anything, r).Once() @@ -1037,9 +1047,9 @@ type mockHandler struct { } func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} -func (m *mockHandler) stop() {} -func (m *mockHandler) probeAvailability() {} -func (m *mockHandler) id() handlerID { return handlerID(m.Id) } +func (m *mockHandler) Stop() {} +func (m *mockHandler) ProbeAvailability() {} +func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} @@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { name string initialHandlers registeredHandlerMap updates []handlerWrapper - expectedHandlers map[string]string // map[handlerID]domain + expectedHandlers map[string]string // map[HandlerID]domain description string }{ { @@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { // Check each expected handler for id, expectedDomain := range tt.expectedHandlers { - handler, exists := server.dnsMuxMap[handlerID(id)] + handler, exists := server.dnsMuxMap[types.HandlerID(id)] assert.True(t, exists, "Expected handler %s not found", id) if exists { assert.Equal(t, expectedDomain, handler.domain, @@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) { } // Verify no unexpected handlers exist - for handlerID := range server.dnsMuxMap { - _, expected := tt.expectedHandlers[string(handlerID)] - assert.True(t, expected, "Unexpected handler found: %s", handlerID) + for HandlerID := range server.dnsMuxMap { + _, expected := tt.expectedHandlers[string(HandlerID)] + assert.True(t, expected, "Unexpected handler found: %s", HandlerID) } // Verify the handlerChain state and order @@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) { handlerChain: NewHandlerChain(), wgInterface: &mocWGIface{}, hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), @@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) { ctx: context.Background(), handlerChain: NewHandlerChain(), hostManager: mockHostConfig, - localResolver: &localResolver{}, + localResolver: &local.Resolver{}, service: mockSvc, statusRecorder: peer.NewRecorder("test"), extraDomains: make(map[domain.Domain]int), diff --git a/client/internal/dns/test/mock.go b/client/internal/dns/test/mock.go new file mode 100644 index 000000000..1db452805 --- /dev/null +++ b/client/internal/dns/test/mock.go @@ -0,0 +1,26 @@ +package test + +import ( + "net" + + "github.com/miekg/dns" +) + +type MockResponseWriter struct { + WriteMsgFunc func(m *dns.Msg) error +} + +func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { + if rw.WriteMsgFunc != nil { + return rw.WriteMsgFunc(m) + } + return nil +} + +func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } +func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } +func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (rw *MockResponseWriter) Close() error { return nil } +func (rw *MockResponseWriter) TsigStatus() error { return nil } +func (rw *MockResponseWriter) TsigTimersOnly(bool) {} +func (rw *MockResponseWriter) Hijack() {} diff --git a/client/internal/dns/types/types.go b/client/internal/dns/types/types.go new file mode 100644 index 000000000..5a8be03b7 --- /dev/null +++ b/client/internal/dns/types/types.go @@ -0,0 +1,3 @@ +package types + +type HandlerID string diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index fa69d4934..2fbfb3b91 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -19,6 +19,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" ) @@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string { } // ID returns the unique handler ID -func (u *upstreamResolverBase) id() handlerID { +func (u *upstreamResolverBase) ID() types.HandlerID { servers := slices.Clone(u.upstreamServers) slices.Sort(servers) hash := sha256.New() hash.Write([]byte(u.domain + ":")) hash.Write([]byte(strings.Join(servers, ","))) - return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) + return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } func (u *upstreamResolverBase) MatchSubdomains() bool { return true } -func (u *upstreamResolverBase) stop() { +func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() } @@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) { ) } -// probeAvailability tests all upstream servers simultaneously and +// ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work -func (u *upstreamResolverBase) probeAvailability() { +func (u *upstreamResolverBase) ProbeAvailability() { u.mutex.Lock() defer u.mutex.Unlock() diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 5dbcc9f79..13bc91a37 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -8,6 +8,8 @@ import ( "time" "github.com/miekg/dns" + + "github.com/netbirdio/netbird/client/internal/dns/test" ) func TestUpstreamResolver_ServeDNS(t *testing.T) { @@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { } var responseMSG *dns.Msg - responseWriter := &mockResponseWriter{ + responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { responseMSG = m return nil @@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { resolver.failsTillDeact = 0 resolver.reactivatePeriod = time.Microsecond * 100 - responseWriter := &mockResponseWriter{ + responseWriter := &test.MockResponseWriter{ WriteMsgFunc: func(m *dns.Msg) error { return nil }, } diff --git a/dns/dns.go b/dns/dns.go index 3a1c76e56..f889a32ec 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -66,17 +66,17 @@ func (s SimpleRecord) String() string { func (s SimpleRecord) Len() uint16 { emptyString := s.RData == "" switch s.Type { - case 1: + case int(dns.TypeA): if emptyString { return 0 } return net.IPv4len - case 5: + case int(dns.TypeCNAME): if emptyString || s.RData == "." { return 1 } return uint16(len(s.RData) + 1) - case 28: + case int(dns.TypeAAAA): if emptyString { return 0 } From 59faaa99f691d5ed9f743e7aca298ecc9af621d1 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Mon, 5 May 2025 17:05:01 +0300 Subject: [PATCH 38/45] [client] Improve NetBird installation script to handle daemon connection timeout (#3761) [client] Improve NetBird installation script to handle daemon connection timeout --- release_files/install.sh | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index 49e313f2f..da5c613d5 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -224,16 +224,22 @@ check_use_bin_variable() { install_netbird() { if [ -x "$(command -v netbird)" ]; then - status_output=$(netbird status) - if echo "$status_output" | grep -q 'Management: Connected' && echo "$status_output" | grep -q 'Signal: Connected'; then - echo "NetBird service is running, please stop it before proceeding" - exit 1 - fi + status_output="$(netbird status 2>&1 || true)" - if [ -n "$status_output" ]; then - echo "NetBird seems to be installed already, please remove it before proceeding" - exit 1 - fi + if echo "$status_output" | grep -q 'failed to connect to daemon error: context deadline exceeded'; then + echo "Warning: could not reach NetBird daemon (timeout), proceeding anyway" + else + if echo "$status_output" | grep -q 'Management: Connected' && \ + echo "$status_output" | grep -q 'Signal: Connected'; then + echo "NetBird service is running, please stop it before proceeding" + exit 1 + fi + + if [ -n "$status_output" ]; then + echo "NetBird seems to be installed already, please remove it before proceeding" + exit 1 + fi + fi fi # Run the installation, if a desktop environment is not detected From 25faf9283da3361f77b1892fe52847235d340305 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 5 May 2025 18:21:48 +0200 Subject: [PATCH 39/45] [management] removal of foreign key constraint enforcement on sqlite (#3786) --- management/server/store/sql_store.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d0adad6ee..dd39cf77d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -82,10 +82,6 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } conns = 1 - _, err = sql.Exec("PRAGMA foreign_keys = ON") - if err != nil { - return nil, fmt.Errorf("failed to set foreign keys for sqlite: %w", err) - } } sql.SetMaxOpenConns(conns) From ac135ab11dd4001b40bb47db97fcedbdde259f88 Mon Sep 17 00:00:00 2001 From: "M. Essam" Date: Mon, 5 May 2025 19:54:47 +0300 Subject: [PATCH 40/45] [management/client/rest] fix panic when body is nil (#3714) Fixes panic occurring when body is nil (this usually happens when connections is refused) due to lack of nil check by centralizing response.Body.Close() behavior. --- management/client/rest/accounts.go | 12 +++-- management/client/rest/dns.go | 28 +++++++++--- management/client/rest/events.go | 4 +- management/client/rest/geo.go | 8 +++- management/client/rest/groups.go | 20 ++++++--- management/client/rest/networks.go | 60 ++++++++++++++++++------- management/client/rest/peers.go | 20 ++++++--- management/client/rest/policies.go | 20 ++++++--- management/client/rest/posturechecks.go | 20 ++++++--- management/client/rest/routes.go | 20 ++++++--- management/client/rest/setupkeys.go | 20 ++++++--- management/client/rest/tokens.go | 16 +++++-- management/client/rest/users.go | 24 +++++++--- 13 files changed, 204 insertions(+), 68 deletions(-) diff --git a/management/client/rest/accounts.go b/management/client/rest/accounts.go index f38b19f70..29d4ac79d 100644 --- a/management/client/rest/accounts.go +++ b/management/client/rest/accounts.go @@ -20,7 +20,9 @@ func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Account](resp) return ret, err } @@ -36,7 +38,9 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api. if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Account](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/dns.go b/management/client/rest/dns.go index ef9923b1f..0e2d15842 100644 --- a/management/client/rest/dns.go +++ b/management/client/rest/dns.go @@ -20,7 +20,9 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.NameserverGroup](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID strin if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NameserverGroup](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NameserverGroup](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NameserverGroup](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -88,7 +98,9 @@ func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.DNSSettings](resp) return &ret, err } @@ -104,7 +116,9 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.DNSSettings](resp) return &ret, err } diff --git a/management/client/rest/events.go b/management/client/rest/events.go index 1157700ff..ed74fae39 100644 --- a/management/client/rest/events.go +++ b/management/client/rest/events.go @@ -18,7 +18,9 @@ func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Event](resp) return ret, err } diff --git a/management/client/rest/geo.go b/management/client/rest/geo.go index ed9090fe2..0bdcc0a22 100644 --- a/management/client/rest/geo.go +++ b/management/client/rest/geo.go @@ -18,7 +18,9 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Country](resp) return ret, err } @@ -30,7 +32,9 @@ func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode stri if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.City](resp) return ret, err } diff --git a/management/client/rest/groups.go b/management/client/rest/groups.go index feb664273..aac453b93 100644 --- a/management/client/rest/groups.go +++ b/management/client/rest/groups.go @@ -20,7 +20,9 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Group](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Group](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Group](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Group](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/networks.go b/management/client/rest/networks.go index 2cdd6d73d..b211312c9 100644 --- a/management/client/rest/networks.go +++ b/management/client/rest/networks.go @@ -20,7 +20,9 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Network](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Network](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Network](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api. if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Network](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -102,7 +112,9 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.NetworkResource](resp) return ret, err } @@ -114,7 +126,9 @@ func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkResource](resp) return &ret, err } @@ -130,7 +144,9 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkResource](resp) return &ret, err } @@ -146,7 +162,9 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkResource](resp) return &ret, err } @@ -158,7 +176,9 @@ func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID stri if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -184,7 +204,9 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.NetworkRouter](resp) return ret, err } @@ -196,7 +218,9 @@ func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*a if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkRouter](resp) return &ret, err } @@ -212,7 +236,9 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkRouter](resp) return &ret, err } @@ -228,7 +254,9 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.NetworkRouter](resp) return &ret, err } @@ -240,7 +268,9 @@ func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/peers.go b/management/client/rest/peers.go index 9d35f013c..2b1a65b4c 100644 --- a/management/client/rest/peers.go +++ b/management/client/rest/peers.go @@ -20,7 +20,9 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Peer](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Peer](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Peer](resp) return &ret, err } @@ -60,7 +66,9 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -72,7 +80,9 @@ func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]ap if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Peer](resp) return ret, err } diff --git a/management/client/rest/policies.go b/management/client/rest/policies.go index be6abafaf..975a95440 100644 --- a/management/client/rest/policies.go +++ b/management/client/rest/policies.go @@ -20,7 +20,9 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Policy](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, er if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Policy](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Policy](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Policy](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/posturechecks.go b/management/client/rest/posturechecks.go index 950d17ba0..7343957a5 100644 --- a/management/client/rest/posturechecks.go +++ b/management/client/rest/posturechecks.go @@ -20,7 +20,9 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.PostureCheck](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PostureCheck](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PostureCheck](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PostureCheck](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) er if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/routes.go b/management/client/rest/routes.go index bccbb8847..6ca4be2c5 100644 --- a/management/client/rest/routes.go +++ b/management/client/rest/routes.go @@ -20,7 +20,9 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.Route](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Route](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Route](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.Route](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/setupkeys.go b/management/client/rest/setupkeys.go index 645614fcf..91f370663 100644 --- a/management/client/rest/setupkeys.go +++ b/management/client/rest/setupkeys.go @@ -20,7 +20,9 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.SetupKey](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKe if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.SetupKey](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJ if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.SetupKeyClear](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.SetupKey](resp) return &ret, err } @@ -76,7 +84,9 @@ func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/tokens.go b/management/client/rest/tokens.go index 3275bea81..7e5004147 100644 --- a/management/client/rest/tokens.go +++ b/management/client/rest/tokens.go @@ -20,7 +20,9 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.PersonalAccessToken](resp) return ret, err } @@ -32,7 +34,9 @@ func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.Perso if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PersonalAccessToken](resp) return &ret, err } @@ -48,7 +52,9 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp) return &ret, err } @@ -60,7 +66,9 @@ func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } diff --git a/management/client/rest/users.go b/management/client/rest/users.go index 31ffad051..bb81796c0 100644 --- a/management/client/rest/users.go +++ b/management/client/rest/users.go @@ -20,7 +20,9 @@ func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[[]api.User](resp) return ret, err } @@ -36,7 +38,9 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.User](resp) return &ret, err } @@ -52,7 +56,9 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.User](resp) return &ret, err } @@ -64,7 +70,9 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -76,7 +84,9 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error { if err != nil { return err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } return nil } @@ -88,7 +98,9 @@ func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) { if err != nil { return nil, err } - defer resp.Body.Close() + if resp.Body != nil { + defer resp.Body.Close() + } ret, err := parseResponse[api.User](resp) return &ret, err From ebda0fc5385d26ebdbd83904f6d72e36f11ce156 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 6 May 2025 18:31:03 +0300 Subject: [PATCH 41/45] [management] Delete service users with account manager (#3793) --- management/server/account.go | 8 ++++-- management/server/account_test.go | 44 +++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index aa7cb0019..5c474a343 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -603,11 +603,15 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u } for _, otherUser := range account.Users { - if otherUser.IsServiceUser { + if otherUser.Id == userID { continue } - if otherUser.Id == userID { + if otherUser.IsServiceUser { + err = am.deleteServiceUser(ctx, accountID, userID, otherUser) + if err != nil { + return err + } continue } diff --git a/management/server/account_test.go b/management/server/account_test.go index fe082d9a0..c5583d226 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -853,6 +853,42 @@ func TestAccountManager_DeleteAccount(t *testing.T) { t.Fatal(err) } + account.Users["service-user-1"] = &types.User{ + Id: "service-user-1", + Role: types.UserRoleAdmin, + IsServiceUser: true, + Issued: types.UserIssuedAPI, + PATs: map[string]*types.PersonalAccessToken{ + "pat-1": { + ID: "pat-1", + UserID: "service-user-1", + Name: "service-user-1", + HashedToken: "hashedToken", + CreatedAt: time.Now(), + }, + }, + } + account.Users[userId] = &types.User{ + Id: "service-user-2", + Role: types.UserRoleUser, + IsServiceUser: true, + Issued: types.UserIssuedAPI, + PATs: map[string]*types.PersonalAccessToken{ + "pat-2": { + ID: "pat-2", + UserID: userId, + Name: userId, + HashedToken: "hashedToken", + CreatedAt: time.Now(), + }, + }, + } + + err = manager.Store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatal(err) + } + err = manager.DeleteAccount(context.Background(), account.Id, userId) if err != nil { t.Fatal(err) @@ -862,6 +898,14 @@ func TestAccountManager_DeleteAccount(t *testing.T) { if err == nil { t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) } + + pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") + require.NoError(t, err) + assert.Len(t, pats, 0) + + pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) + require.NoError(t, err) + assert.Len(t, pats, 0) } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { From fcd2c15a37320064ef29aa036ebf32a8d5b69714 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 7 May 2025 07:25:25 +0200 Subject: [PATCH 42/45] [management] policy delete cleans policy rules (#3788) --- management/server/store/sql_store.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index dd39cf77d..d568460f9 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1683,18 +1683,26 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, } func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID) - if err := result.Error; err != nil { - log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) - return status.Errorf(status.Internal, "failed to delete policy from store") - } + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { + return fmt.Errorf("delete policy rules: %w", err) + } - if result.RowsAffected == 0 { - return status.NewPolicyNotFoundError(policyID) - } + result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where(accountAndIDQueryCondition, accountID, policyID). + Delete(&types.Policy{}) - return nil + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil + }) } // GetAccountPostureChecks retrieves posture checks for an account. From cad2fe1f39f7ce850692feb4a857477717325337 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 9 May 2025 13:56:27 +0200 Subject: [PATCH 43/45] Return with the correct copied length (#3804) --- sharedsock/sock_linux.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 74ac6c163..1c22e7869 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -234,7 +234,7 @@ func (s *SharedSocket) read(receiver receiver) { } // ReadFrom reads packets received in the packetDemux channel -func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func (s *SharedSocket) ReadFrom(b []byte) (int, net.Addr, error) { var pkt rcvdPacket select { case <-s.ctx.Done(): @@ -263,8 +263,7 @@ func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { decodedLayers := make([]gopacket.LayerType, 0, 3) - err = parser.DecodeLayers(pkt.buf, &decodedLayers) - if err != nil { + if err := parser.DecodeLayers(pkt.buf, &decodedLayers); err != nil { return 0, nil, err } @@ -273,8 +272,8 @@ func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { Port: int(udp.SrcPort), } - copy(b, payload) - return int(udp.Length), remoteAddr, nil + n := copy(b, payload) + return n, remoteAddr, nil } // WriteTo builds a UDP packet and writes it using the specific IP version writer From d5b52e86b6386e7dab2d370d9f49da7e96893e3b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 9 May 2025 14:01:21 +0200 Subject: [PATCH 44/45] [client] Ignore irrelevant route changes to tracked network monitor routes (#3796) --- .../networkmonitor/check_change_bsd.go | 2 +- .../networkmonitor/check_change_windows.go | 56 ++- .../check_change_windows_test.go | 404 ++++++++++++++++++ client/internal/networkmonitor/monitor.go | 5 +- .../routemanager/systemops/systemops.go | 15 + .../systemops/systemops_windows.go | 11 +- 6 files changed, 464 insertions(+), 29 deletions(-) create mode 100644 client/internal/networkmonitor/check_change_windows_test.go diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index bb327a877..f5eb2c739 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -19,7 +19,7 @@ import ( func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) if err != nil { - return fmt.Errorf("failed to open routing socket: %v", err) + return fmt.Errorf("open routing socket: %v", err) } defer func() { err := unix.Close(fd) diff --git a/client/internal/networkmonitor/check_change_windows.go b/client/internal/networkmonitor/check_change_windows.go index 582865738..814584863 100644 --- a/client/internal/networkmonitor/check_change_windows.go +++ b/client/internal/networkmonitor/check_change_windows.go @@ -13,7 +13,7 @@ import ( func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { routeMonitor, err := systemops.NewRouteMonitor(ctx) if err != nil { - return fmt.Errorf("failed to create route monitor: %w", err) + return fmt.Errorf("create route monitor: %w", err) } defer func() { if err := routeMonitor.Stop(); err != nil { @@ -38,35 +38,49 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool { - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - if isSoftInterface(intf) { - log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf) - return false - } + if intf := route.NextHop.Intf; intf != nil && isSoftInterface(intf.Name) { + log.Debugf("Network monitor: ignoring default route change for next hop with soft interface %s", route.NextHop) + return false + } + + // TODO: for the empty nexthop ip (on-link), determine the family differently + nexthop := nexthopv4 + if route.NextHop.IP.Is6() { + nexthop = nexthopv6 } switch route.Type { - case systemops.RouteModified: - // TODO: get routing table to figure out if our route is affected for modified routes - log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) - return true - case systemops.RouteAdded: - if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { - log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) - return true - } + case systemops.RouteModified, systemops.RouteAdded: + return handleRouteAddedOrModified(route, nexthop) case systemops.RouteDeleted: - if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) - return true - } + return handleRouteDeleted(route, nexthop) } return false } +func handleRouteAddedOrModified(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool { + // For added/modified routes, we care about different next hops + if !nexthop.Equal(route.NextHop) { + action := "changed" + if route.Type == systemops.RouteAdded { + action = "added" + } + log.Infof("Network monitor: default route %s: via %s", action, route.NextHop) + return true + } + return false +} + +func handleRouteDeleted(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool { + // For deleted routes, we care about our tracked next hop being deleted + if nexthop.Equal(route.NextHop) { + log.Infof("Network monitor: default route removed: via %s", route.NextHop) + return true + } + return false +} + func isSoftInterface(name string) bool { return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") } diff --git a/client/internal/networkmonitor/check_change_windows_test.go b/client/internal/networkmonitor/check_change_windows_test.go new file mode 100644 index 000000000..29ff34dca --- /dev/null +++ b/client/internal/networkmonitor/check_change_windows_test.go @@ -0,0 +1,404 @@ +package networkmonitor + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func TestRouteChanged(t *testing.T) { + tests := []struct { + name string + route systemops.RouteUpdate + nexthopv4 systemops.Nexthop + nexthopv6 systemops.Nexthop + expected bool + }{ + { + name: "soft interface should be ignored", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Name: "ISATAP-Interface", // isSoftInterface checks name + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.2"), + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: false, + }, + { + name: "modified route with different v4 nexthop IP should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.2"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: true, + }, + { + name: "modified route with same v4 nexthop (IP and Intf Index) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: false, + }, + { + name: "added route with different v6 nexthop IP should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteAdded, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::2"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + expected: true, + }, + { + name: "added route with same v6 nexthop (IP and Intf Index) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteAdded, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + expected: false, + }, + { + name: "deleted route matching tracked v4 nexthop (IP and Intf Index) should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: true, + }, + { + name: "deleted route not matching tracked v4 nexthop (different IP) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.3"), // Different IP + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{ + Index: 1, Name: "eth0", + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + }, + expected: false, + }, + { + name: "modified v4 route with same IP, different Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "modified v4 route with same IP, one Intf nil, other non-nil should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: nil, // Intf is nil + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, // Tracked Intf is not nil + }, + expected: true, + }, + { + name: "added v4 route with same IP, different Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteAdded, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "deleted v4 route with same IP, different Intf Index should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ // This is the route being deleted + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv4: systemops.Nexthop{ // This is our tracked nexthop + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + expected: false, // Because nexthopv4.Equal(route.NextHop) will be false + }, + { + name: "modified v6 route with different IP, same Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::3"), // Different IP + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "modified v6 route with same IP, different Intf Index should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "modified v6 route with same IP, same Intf Index should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteModified, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: false, + }, + { + name: "deleted v6 route matching tracked nexthop (IP and Intf Index) should return true", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: true, + }, + { + name: "deleted v6 route not matching tracked nexthop (different IP) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::3"), // Different IP + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: false, + }, + { + name: "deleted v6 route not matching tracked nexthop (same IP, different Intf Index) should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteDeleted, + Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + NextHop: systemops.Nexthop{ // This is the route being deleted + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv6: systemops.Nexthop{ // This is our tracked nexthop + IP: netip.MustParseAddr("2001:db8::1"), + Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index + }, + expected: false, + }, + { + name: "unknown route type should return false", + route: systemops.RouteUpdate{ + Type: systemops.RouteUpdateType(99), // Unknown type + Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + NextHop: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.1"), + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + }, + nexthopv4: systemops.Nexthop{ + IP: netip.MustParseAddr("192.168.1.2"), // Different from route.NextHop + Intf: &net.Interface{Index: 1, Name: "eth0"}, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := routeChanged(tt.route, tt.nexthopv4, tt.nexthopv6) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsSoftInterface(t *testing.T) { + tests := []struct { + name string + ifname string + expected bool + }{ + { + name: "ISATAP interface should be detected", + ifname: "ISATAP tunnel adapter", + expected: true, + }, + { + name: "lowercase soft interface should be detected", + ifname: "isatap.{14A5CF17-CA72-43EC-B4EA-B4B093641B7D}", + expected: true, + }, + { + name: "Teredo interface should be detected", + ifname: "Teredo Tunneling Pseudo-Interface", + expected: true, + }, + { + name: "regular interface should not be detected as soft", + ifname: "eth0", + expected: false, + }, + { + name: "another regular interface should not be detected as soft", + ifname: "wlan0", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isSoftInterface(tt.ifname) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 5896b66b6..accdd9c9d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -118,9 +118,12 @@ func (nw *NetworkMonitor) Stop() { } func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { + defer close(event) for { if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { - close(event) + if !errors.Is(err, context.Canceled) { + log.Errorf("Network monitor: failed to check for changes: %v", err) + } return } // prevent blocking diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 5c117b94d..fd511fc20 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -1,6 +1,7 @@ package systemops import ( + "fmt" "net" "net/netip" "sync" @@ -15,6 +16,20 @@ type Nexthop struct { Intf *net.Interface } +// Equal checks if two nexthops are equal. +func (n Nexthop) Equal(other Nexthop) bool { + return n.IP == other.IP && (n.Intf == nil && other.Intf == nil || + n.Intf != nil && other.Intf != nil && n.Intf.Index == other.Intf.Index) +} + +// String returns a string representation of the nexthop. +func (n Nexthop) String() string { + if n.Intf == nil { + return n.IP.String() + } + return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name) +} + type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index ad325e123..f66161595 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -33,8 +33,7 @@ type RouteUpdateType int type RouteUpdate struct { Type RouteUpdateType Destination netip.Prefix - NextHop netip.Addr - Interface *net.Interface + NextHop Nexthop } // RouteMonitor provides a way to monitor changes in the routing table. @@ -231,15 +230,15 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI intf, err := net.InterfaceByIndex(idx) if err != nil { log.Warnf("failed to get interface name for index %d: %v", idx, err) - update.Interface = &net.Interface{ + update.NextHop.Intf = &net.Interface{ Index: idx, } } else { - update.Interface = intf + update.NextHop.Intf = intf } } - log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) + log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.NextHop.Intf) dest := parseIPPrefix(row.DestinationPrefix, idx) if !dest.Addr().IsValid() { return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row) @@ -262,7 +261,7 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI update.Type = updateType update.Destination = dest - update.NextHop = nexthop + update.NextHop.IP = nexthop return update, nil } From 2f34e984b0051c5b0ad1992d88b93c62165bf765 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Fri, 9 May 2025 15:06:34 +0300 Subject: [PATCH 45/45] [client] Add TCP support to DNS forwarder service listener (#3790) [client] Add TCP support to DNS forwarder service listener --- client/internal/dnsfwd/forwarder.go | 104 ++++++++++++++++++++++------ client/internal/dnsfwd/manager.go | 14 ++++ 2 files changed, 98 insertions(+), 20 deletions(-) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 8f6a31f47..45b479632 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -33,6 +33,8 @@ type DNSForwarder struct { dnsServer *dns.Server mux *dns.ServeMux + tcpServer *dns.Server + tcpMux *dns.ServeMux mutex sync.RWMutex fwdEntries []*ForwarderEntry @@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("listen DNS forwarder on address=%s", f.listenAddress) - mux := dns.NewServeMux() + log.Infof("starting DNS forwarder on address=%s", f.listenAddress) - dnsServer := &dns.Server{ + // UDP server + mux := dns.NewServeMux() + f.mux = mux + f.dnsServer = &dns.Server{ Addr: f.listenAddress, Net: "udp", Handler: mux, } - f.dnsServer = dnsServer - f.mux = mux + // TCP server + tcpMux := dns.NewServeMux() + f.tcpMux = tcpMux + f.tcpServer = &dns.Server{ + Addr: f.listenAddress, + Net: "tcp", + Handler: tcpMux, + } f.UpdateDomains(entries) - return dnsServer.ListenAndServe() -} + errCh := make(chan error, 2) + go func() { + log.Infof("DNS UDP listener running on %s", f.listenAddress) + errCh <- f.dnsServer.ListenAndServe() + }() + go func() { + log.Infof("DNS TCP listener running on %s", f.listenAddress) + errCh <- f.tcpServer.ListenAndServe() + }() + + // return the first error we get (e.g. bind failure or shutdown) + return <-errCh +} func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() @@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { } oldDomains := filterDomains(f.fwdEntries) - for _, d := range oldDomains { f.mux.HandleRemove(d.PunycodeString()) + f.tcpMux.HandleRemove(d.PunycodeString()) } newDomains := filterDomains(entries) for _, d := range newDomains { - f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery) + f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP) + f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP) } f.fwdEntries = entries - log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) } func (f *DNSForwarder) Close(ctx context.Context) error { - if f.dnsServer == nil { - return nil + var result *multierror.Error + + if f.dnsServer != nil { + if err := f.dnsServer.ShutdownContext(ctx); err != nil { + result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err)) + } } - return f.dnsServer.ShutdownContext(ctx) + if f.tcpServer != nil { + if err := f.tcpServer.ShutdownContext(ctx); err != nil { + result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg { if len(query.Question) == 0 { - return + return nil } question := query.Question[0] log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", @@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } - return + return nil } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) if err != nil { - f.handleDNSError(w, resp, domain, err) - return + f.handleDNSError(w, query, resp, domain, err) + return nil } f.updateInternalState(domain, ips) f.addIPsToResponse(resp, domain, ips) + return resp +} + +func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { + + resp := f.handleDNSQuery(w, query) + if resp == nil { + return + } + + opt := query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + // client advertised a larger EDNS0 buffer + maxSize = int(opt.UDPSize()) + } + + // if our response is too big, truncate and set the TC bit + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { + resp := f.handleDNSQuery(w, query) + if resp == nil { + return + } + if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } @@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe } // handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { +func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { var dnsErr *net.DNSError switch { @@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai } if dnsErr.Server != "" { - log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) } else { log.Warnf(errResolveFailed, domain, err) } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index e4a23450f..91abce823 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -33,6 +33,7 @@ type Manager struct { statusRecorder *peer.Status fwRules []firewall.Rule + tcpRules []firewall.Rule dnsForwarder *DNSForwarder } @@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error { } m.fwRules = dnsRules + tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + m.tcpRules = tcpRules + return nil } @@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) } } + for _, rule := range m.tcpRules { + if err := m.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } m.fwRules = nil + m.tcpRules = nil return nberrors.FormatErrorOrNil(mErr) }