Compare commits

...

26 Commits

Author SHA1 Message Date
Maycon Santos
840b07c784 add todos 2024-07-05 11:15:28 +02:00
Maycon Santos
85e991ff78 Fix issue with canceled context before pushing metrics and decreasing pushing interval (#2235)
Fix a bug where the post context was canceled before sending metrics to the server.

The interval time was decreased, and an optional environment variable NETBIRD_METRICS_INTERVAL_IN_SECONDS was added to control the interval time.

* update doc URL
2024-07-04 19:15:59 +02:00
Maycon Santos
f9845e53a0 Sort routes by ID and remove DNS routes from overlapping list (#2234) 2024-07-04 16:50:07 +02:00
pascal-fischer
765aba2c1c Add context to throughout the project and update logging (#2209)
propagate context from all the API calls and log request ID, account ID and peer ID

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-07-03 11:33:02 +02:00
Zoltan Papp
7cb81f1d70 Fix nil pointer exception in case of error (#2230) 2024-07-02 18:18:14 +02:00
Viktor Liu
cea19de667 Debounce network monitor restarts (#2225) 2024-07-02 17:09:00 +02:00
Bethuel Mmbaga
29e5eceb6b Fix linux serial number retrieval (#2206)
* Change source of serial number in sysInfo function

The serial number returned by the sysInfo function in info_linux.go has been fixed. Previously, it was incorrectly fetched from the Chassis object. Now it is correctly fetched from the Product object. This aligns better with the expected system info retrieval method.

* Fallback to product.Serial in sys info

In case of the chassis is "Default String" or empty then try to use product.serial

---------

Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2024-07-02 13:19:08 +02:00
dependabot[bot]
0f63737330 Bump golang.org/x/image from 0.10.0 to 0.18.0 (#2205)
Bumps [golang.org/x/image](https://github.com/golang/image) from 0.10.0 to 0.18.0.
- [Commits](https://github.com/golang/image/compare/v0.10.0...v0.18.0)

---
updated-dependencies:
- dependency-name: golang.org/x/image
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-07-02 13:12:28 +02:00
Viktor Liu
bf518c5fba Remove interface network monitor checks (#2223) 2024-07-02 12:41:15 +02:00
Maycon Santos
eab6183a8e Add stack trace when saving empty domains (#2228)
added temporary domain check for existing accounts to trace where the issue originated

Refactor save account due to complexity score
2024-07-02 12:40:26 +02:00
Yxnt
4517da8b3a Feat: Client UI Multiple Language Support (#2192)
Signed-off-by: Yxnt <jyxnt1@gmail.com>
2024-07-02 12:47:26 +03:00
Maycon Santos
9c0d923124 fix: client/Dockerfile to reduce vulnerabilities (#2220)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249236
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249236
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249265
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249265
- https://snyk.io/vuln/SNYK-ALPINE318-BUSYBOX-7249419

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2024-07-02 09:42:30 +02:00
Maycon Santos
6857734c48 add MACOSX_DEPLOYMENT_TARGET environment to control GUI build target (#2221)
Add MACOSX_DEPLOYMENT_TARGET and MACOS_DEPLOYMENT_TARGET to target build compatible with macOS 11+ instead of relying on the builder's local Xcode version.
2024-07-01 17:59:09 +02:00
Maycon Santos
3b019800f8 Remove DNSSEC parameters and configure AuthenticatedData (#2208) 2024-06-27 18:36:24 +02:00
Maycon Santos
4cd4f88666 Add multiple tabs for route selection (#2198)
Add all routes, overlapping and exit routes tabs
2024-06-27 14:32:30 +02:00
Maycon Santos
d2157bda66 Set EDNS0 when no extra options are set by the dns client (#2195) 2024-06-25 17:18:04 +02:00
Maycon Santos
43a8ba97e3 Add log config and removed domain (#2194)
removed domainname for coturn service as it is needed only for SSL configs

Added log configuration for each service with a rotation and max size

ensure ZITADEL_DATABASE=postgres works
2024-06-25 13:54:09 +02:00
Robert Neumann
17874771cc Feature/Use Zitadel Postgres Integration by default (#2181)
replaces cockroachDB as default DB for Zitadel in the getting started script to deploy script. Users can switch back to cockroachDB by setting the environment variable ZITADEL_DATABASE to cockroach.
2024-06-25 11:10:11 +02:00
Viktor Liu
f6ccf6b97a Improve windows network monitor (#2184)
* Allow other states for windows neighbor network monitor

* Allow windows route network monitor to check for multiple default routes
2024-06-25 10:35:51 +02:00
Viktor Liu
6aae797baf Add loopback ignore rule to nat chains (#2190)
This makes sure loopback traffic is not affected by NAT
2024-06-25 09:43:36 +02:00
Maycon Santos
aca054e51e Using macOS-latest to build GUI (#2189) 2024-06-25 09:34:02 +02:00
Maycon Santos
10cee8f46e Use selector to display dns routes in GUI (#2185)
Use select widget for dns routes on GUI
2024-06-24 16:18:00 +02:00
Viktor Liu
628673db20 Lower retry interval on dns resolve failure (#2176) 2024-06-24 11:55:07 +02:00
Bethuel Mmbaga
eaa31c2dc6 Optimize process checks database read (#2182)
* Add posture checks to peer management

This commit includes posture checks to the peer management logic. The AddPeer, SyncPeer and LoginPeer functions now return a list of posture checks along with the peer and network map.

* Update peer methods to return posture checks

* Refactor

* return early if there is no posture checks

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-06-22 17:41:16 +03:00
Zoltan Papp
25723e9b07 Do not use eBPF proxy in case of USP mode (#2180) 2024-06-22 15:33:10 +02:00
Robert Neumann
3cf4d5758f Update Zitadel and CockroachDB Container Image Version (#2169)
* fix type in docker compose

* Update docker compose cockroachdb to latest-23.2 and zitadel to 2.54.3
2024-06-22 12:44:45 +02:00
149 changed files with 3852 additions and 3044 deletions

View File

@@ -173,7 +173,7 @@ jobs:
retention-days: 3 retention-days: 3
release_ui_darwin: release_ui_darwin:
runs-on: macos-11 runs-on: macos-latest
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@@ -178,34 +178,79 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: run script - name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen - name: test Caddy file gen postgres
run: test -f Caddyfile run: test -f Caddyfile
- name: test docker-compose file gen
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml run: test -f docker-compose.yml
- name: test management.json file gen
- name: test management.json file gen postgres
run: test -f management.json run: test -f management.json
- name: test turnserver.conf file gen
- name: test turnserver.conf file gen postgres
run: | run: |
set -x set -x
test -f turnserver.conf test -f turnserver.conf
grep external-ip turnserver.conf grep external-ip turnserver.conf
- name: test zitadel.env file gen
- name: test zitadel.env file gen postgres
run: test -f zitadel.env run: test -f zitadel.env
- name: test dashboard.env file gen
- name: test dashboard.env file gen postgres
run: test -f dashboard.env run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script: test-download-geolite2-script:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install jq - name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3 run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: test script - name: test script
run: bash -x infrastructure_files/download-geolite2.sh run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists - name: test mmdb file exists
run: test -f GeoLite2-City.mmdb run: test -f GeoLite2-City.mmdb
- name: test geonames file exists - name: test geonames file exists
run: test -f geonames.db run: test -f geonames.db

View File

@@ -3,8 +3,10 @@ builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: [CGO_ENABLED=1] env:
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
goos: goos:
- darwin - darwin
goarch: goarch:

View File

@@ -1,4 +1,4 @@
FROM alpine:3.18.5 FROM alpine:3.19
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir) store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(eventStore) iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil return nil
} }
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil { if err != nil {
return err return err
} }
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil { if err != nil {
return err return err
} }
@@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
} }
delete(i.rules, ruleKey) delete(i.rules, ruleKey)
} }
err = i.iptablesClient.Insert(table, chain, 1, rule...) err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil { if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
} }
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil { if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil return nil
} }
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification // genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string { func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump} return []string{"-s", source, "-d", destination, "-j", jump}

View File

@@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.InsertRoutingRules(pair) return m.router.AddRoutingRules(pair)
} }
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@@ -22,6 +22,8 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc" userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst" userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
) )
// some presets for building nftable rules // some presets for building nftable rules
@@ -126,6 +128,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap() err := r.refreshRulesMap()
if err != nil { if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@@ -138,28 +156,28 @@ func (r *router) createContainers() error {
return nil return nil
} }
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain // AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error { func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap() err := r.refreshRulesMap()
if err != nil { if err != nil {
return err return err
} }
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil { if err != nil {
return err return err
} }
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil { if err != nil {
return err return err
} }
if pair.Masquerade { if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil { if err != nil {
return err return err
} }
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil { if err != nil {
return err return err
} }
@@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
return nil return nil
} }
// insertRoutingRule inserts a nftable rule to the conn client flush queue // addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router
} }
} }
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainName], Chain: r.chains[chainName],
Exprs: expression, Exprs: expression,

View File

@@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.InputPair) err = manager.AddRoutingRules(testCase.InputPair)
defer func() { defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair) _ = manager.RemoveRoutingRules(testCase.InputPair)
}() }()

View File

@@ -78,6 +78,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}() }()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():

View File

@@ -282,8 +282,6 @@ func (e *Engine) Start() error {
} }
e.ctx, e.cancel = context.WithCancel(e.clientCtx) e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@@ -291,6 +289,9 @@ func (e *Engine) Start() error {
} }
e.wgInterface = wgIface e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive { if e.config.RosenpassPermissive {
@@ -1464,6 +1465,15 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
} }
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() { func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor { if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting") log.Infof("Network monitor is disabled, not starting")
@@ -1472,14 +1482,29 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New() e.networkMonitor = networkmonitor.New()
go func() { go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() { err := e.networkMonitor.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine") // This function is called when a network change is detected
if err := e.Stop(); err != nil { mu.Lock()
log.Errorf("Failed to stop engine: %v", err) defer mu.Unlock()
}
if err := e.Start(); err != nil { if debounceTimer != nil {
log.Errorf("Failed to start engine: %v", err) debounceTimer.Stop()
} }
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
}) })
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) { if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err) log.Errorf("Network monitor: %v", err)

View File

@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
//time.Sleep(250 * time.Millisecond) // time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer) assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir) store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -45,24 +45,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type { switch msg.Type {
// handle interface state changes
case unix.RTM_IFINFO:
ifinfo, err := parseInterfaceMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing interface message: %v", err)
continue
}
if msg.Flags&unix.IFF_UP != 0 {
continue
}
if (nexthopv4.Intf == nil || ifinfo.Index != nexthopv4.Intf.Index) && (nexthopv6.Intf == nil || ifinfo.Index != nexthopv6.Intf.Index) {
continue
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
go callback()
// handle route changes // handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE: case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n]) route, err := parseRouteMessage(buf[:n])
@@ -94,24 +76,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
} }
} }
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.InterfaceMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return msg, nil
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) { func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil { if err != nil {

View File

@@ -19,14 +19,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return errors.New("no interfaces available") return errors.New("no interfaces available")
} }
linkChan := make(chan netlink.LinkUpdate)
done := make(chan struct{}) done := make(chan struct{})
defer close(done) defer close(done)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
return fmt.Errorf("subscribe to link updates: %v", err)
}
routeChan := make(chan netlink.RouteUpdate) routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil { if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err) return fmt.Errorf("subscribe to route updates: %v", err)
@@ -38,25 +33,6 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ErrStopped
// handle interface state changes
case update := <-linkChan:
if (nexthopv4.Intf == nil || update.Index != int32(nexthopv4.Intf.Index)) && (nexthopv6.Intf == nil || update.Index != int32(nexthopv6.Intf.Index)) {
continue
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
go callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
go callback()
return nil
}
}
// handle route changes // handle route changes
case route := <-routeChan: case route := <-routeChan:
// default route and main table // default route and main table

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -33,12 +34,8 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("get neighbors: %w", err) return fmt.Errorf("get neighbors: %w", err)
} }
if n, ok := initialNeighbors[nexthopv4.IP]; ok { neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
neighborv4 = &n neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
}
if n, ok := initialNeighbors[nexthopv6.IP]; ok {
neighborv6 = &n
}
} }
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6) log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
@@ -58,6 +55,16 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
} }
} }
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed( func changed(
nexthopv4 systemops.Nexthop, nexthopv4 systemops.Nexthop,
neighborv4 *systemops.Neighbor, neighborv4 *systemops.Neighbor,
@@ -87,37 +94,64 @@ func changed(
} }
// routeChanged checks if the default routes still point to our nexthop/interface // routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[netip.Prefix]systemops.Route) bool { func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
if !nexthop.IP.IsValid() { if !nexthop.IP.IsValid() {
return false return false
} }
var unspec netip.Prefix unspec := getUnspecifiedPrefix(nexthop.IP)
if nexthop.IP.Is6() { defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
if r, ok := routes[unspec]; ok { log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
oldIntf, newIntf := "<nil>", "<nil>" if !foundMatchingRoute {
if intf != nil { logRouteChange(nexthop.IP, intf)
oldIntf = intf.Name
}
if r.Interface != nil {
newIntf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s from %s (%s) to %s (%s)", r.Destination, nexthop.IP, oldIntf, r.Nexthop, newIntf)
return true
}
} else {
log.Infof("network monitor: default route is gone")
return true return true
} }
return false return false
}
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
} }
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool { func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
@@ -127,7 +161,7 @@ func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, ne
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces // TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop.IP]; ok { if n, ok := neighbors[nexthop.IP]; ok {
if n.State != reachable && n.State != permanent { if n.State == unreachable || n.State == incomplete {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex { } else if n.InterfaceIndex != neighbor.InterfaceIndex {
@@ -165,18 +199,13 @@ func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
return neighbours, nil return neighbours, nil
} }
func getRoutes() (map[netip.Prefix]systemops.Route, error) { func getRoutes() ([]systemops.Route, error) {
entries, err := systemops.GetRoutes() entries, err := systemops.GetRoutes()
if err != nil { if err != nil {
return nil, fmt.Errorf("get routes: %w", err) return nil, fmt.Errorf("get routes: %w", err)
} }
routes := make(map[netip.Prefix]systemops.Route, len(entries)) return entries, nil
for _, entry := range entries {
routes[entry.Destination] = entry
}
return routes, nil
} }
func stateFromInt(state uint8) string { func stateFromInt(state uint8) string {

View File

@@ -36,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -51,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -88,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -124,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -154,7 +154,7 @@ func TestConn_Status(t *testing.T) {
} }
func TestConn_Close(t *testing.T) { func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()

View File

@@ -23,7 +23,8 @@ import (
const ( const (
DefaultInterval = time.Minute DefaultInterval = time.Minute
minInterval = 2 * time.Second minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w" addAllowedIP = "add allowed IP %s: %w"
) )
@@ -160,7 +161,12 @@ func (r *Route) startResolver(ctx context.Context) {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
defer ticker.Stop() defer ticker.Stop()
r.update(ctx) if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for { for {
select { select {
@@ -168,17 +174,28 @@ func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Stopping dynamic route resolver for domains [%v]", r) log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return return
case <-ticker.C: case <-ticker.C:
r.update(ctx) if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
} }
} }
} }
func (r *Route) update(ctx context.Context) { func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil { if resolved, err := r.resolveDomains(); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err) return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil { } else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
log.Errorf("Failed to update dynamic routes for [%v]: %v", r, err) return fmt.Errorf("update dynamic routes: %w", err)
} }
return nil
} }
func (r *Route) resolveDomains() (domainMap, error) { func (r *Route) resolveDomains() (domainMap, error) {

View File

@@ -8,9 +8,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func NewFactory(ctx context.Context, wgPort int) *Factory { func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort} f := &Factory{wgPort: wgPort}
if userspace {
return f
}
ebpfProxy := NewWGEBPFProxy(ctx, wgPort) ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen() err := ebpfProxy.listen()
if err != nil { if err != nil {

View File

@@ -4,6 +4,6 @@ package wgproxy
import "context" import "context"
func NewFactory(ctx context.Context, wgPort int) *Factory { func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort} return &Factory{wgPort: wgPort}
} }

View File

@@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err return nil, "", err
} }
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir) store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -89,5 +89,9 @@ func _getInfo() string {
func sysInfo() (serialNumber string, productName string, manufacturer string) { func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo var si sysinfo.SysInfo
si.GetSysInfo() si.GetSysInfo()
return si.Chassis.Serial, si.Product.Name, si.Product.Vendor serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial
}
return serial, si.Product.Name, si.Product.Vendor
} }

View File

@@ -80,6 +80,7 @@ func main() {
log.Errorf("check PID file: %v", err) log.Errorf("check PID file: %v", err)
return return
} }
client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit) systray.Run(client.onTrayReady, client.onTrayExit)
} }
} }
@@ -876,3 +877,88 @@ func checkPIDFile() error {
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
} }
func (s *serviceClient) setDefaultFonts() {
var (
defaultFontPath string
)
//TODO: Linux Multiple Language Support
switch runtime.GOOS {
case "darwin":
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
case "windows":
fontPath := s.getWindowsFontFilePath()
defaultFontPath = fontPath
}
_, err := os.Stat(defaultFontPath)
if err == nil {
os.Setenv("FYNE_FONT", defaultFontPath)
}
}
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
/*
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
*/
var (
fontFolder string = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
output, err := cmd.Output()
if err != nil {
log.Errorf("Failed to get Windows default language setting: %v", err)
fontPath = path.Join(fontFolder, fontMapping["default"])
return
}
defaultLanguage := strings.TrimSpace(string(output))
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
return
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
fontPath = path.Join(fontFolder, font)
} else {
fontPath = path.Join(fontFolder, fontMapping["default"])
}
return
}

View File

@@ -4,6 +4,7 @@ package main
import ( import (
"fmt" "fmt"
"sort"
"strings" "strings"
"time" "time"
@@ -17,28 +18,57 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
const (
allRoutesText = "All routes"
overlappingRoutesText = "Overlapping routes"
exitNodeRoutesText = "Exit-node routes"
allRoutes filter = "all"
overlappingRoutes filter = "overlapping"
exitNodeRoutes filter = "exit-node"
getClientFMT = "get client: %v"
)
type filter string
func (s *serviceClient) showRoutesUI() { func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes") s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(3)) allGrid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(grid) go s.updateRoutes(allGrid, allRoutes)
overlappingGrid := container.New(layout.NewGridLayout(3))
exitNodeGrid := container.New(layout.NewGridLayout(3))
routeCheckContainer := container.NewVBox() routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid) tabs := container.NewAppTabs(
container.NewTabItem(allRoutesText, allGrid),
container.NewTabItem(overlappingRoutesText, overlappingGrid),
container.NewTabItem(exitNodeRoutesText, exitNodeGrid),
)
tabs.OnSelected = func(item *container.TabItem) {
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}
tabs.OnUnselected = func(item *container.TabItem) {
grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
grid.Objects = nil
}
routeCheckContainer.Add(tabs)
scrollContainer := container.NewVScroll(routeCheckContainer) scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300)) scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox( buttonBox := container.NewHBox(
layout.NewSpacer(), layout.NewSpacer(),
widget.NewButton("Refresh", func() { widget.NewButton("Refresh", func() {
s.updateRoutes(grid) s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}), }),
widget.NewButton("Select all", func() { widget.NewButton("Select all", func() {
s.selectAllRoutes() _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.updateRoutes(grid) s.selectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}), }),
widget.NewButton("Deselect All", func() { widget.NewButton("Deselect All", func() {
s.deselectAllRoutes() _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.updateRoutes(grid) s.deselectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}), }),
layout.NewSpacer(), layout.NewSpacer(),
) )
@@ -48,18 +78,12 @@ func (s *serviceClient) showRoutesUI() {
s.wRoutes.SetContent(content) s.wRoutes.SetContent(content)
s.wRoutes.Show() s.wRoutes.Show()
s.startAutoRefresh(5*time.Second, grid) s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
} }
func (s *serviceClient) updateRoutes(grid *fyne.Container) { func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
grid.Objects = nil grid.Objects = nil
grid.Refresh()
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
@@ -67,7 +91,15 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
grid.Add(idHeader) grid.Add(idHeader)
grid.Add(networkHeader) grid.Add(networkHeader)
grid.Add(resolvedIPsHeader) grid.Add(resolvedIPsHeader)
for _, route := range routes {
filteredRoutes, err := s.getFilteredRoutes(f)
if err != nil {
return
}
sortRoutesByIDs(filteredRoutes)
for _, route := range filteredRoutes {
r := route r := route
checkBox := widget.NewCheck(r.GetID(), func(checked bool) { checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
@@ -80,35 +112,104 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
grid.Add(checkBox) grid.Add(checkBox)
network := r.GetNetwork() network := r.GetNetwork()
domains := r.GetDomains() domains := r.GetDomains()
if len(domains) > 0 {
network = strings.Join(domains, ", ")
}
grid.Add(widget.NewLabel(network))
if len(domains) > 0 { if len(domains) == 0 {
var resolvedIPsList []string grid.Add(widget.NewLabel(network))
for _, domain := range r.GetDomains() {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
// TODO: limit width
resolvedIPsLabel := widget.NewLabel(strings.Join(resolvedIPsList, ", "))
grid.Add(resolvedIPsLabel)
} else {
grid.Add(widget.NewLabel("")) grid.Add(widget.NewLabel(""))
continue
} }
// our selectors are only for display
noopFunc := func(_ string) {
// do nothing
}
domainsSelector := widget.NewSelect(domains, noopFunc)
domainsSelector.Selected = domains[0]
grid.Add(domainsSelector)
var resolvedIPsList []string
for _, domain := range domains {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
if len(resolvedIPsList) == 0 {
grid.Add(widget.NewLabel(""))
continue
}
// TODO: limit width within the selector display
resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc)
resolvedIPsSelector.Selected = resolvedIPsList[0]
resolvedIPsSelector.Resize(fyne.NewSize(100, 100))
grid.Add(resolvedIPsSelector)
} }
s.wRoutes.Content().Refresh() s.wRoutes.Content().Refresh()
grid.Refresh() grid.Refresh()
} }
func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
return nil, err
}
switch f {
case overlappingRoutes:
return getOverlappingRoutes(routes), nil
case exitNodeRoutes:
return getExitNodeRoutes(routes), nil
default:
}
return routes, nil
}
func getOverlappingRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
existingRange := make(map[string][]*proto.Route)
for _, route := range routes {
if len(route.Domains) > 0 {
continue
}
if r, exists := existingRange[route.GetNetwork()]; exists {
r = append(r, route)
existingRange[route.GetNetwork()] = r
} else {
existingRange[route.GetNetwork()] = []*proto.Route{route}
}
}
for _, r := range existingRange {
if len(r) > 1 {
filteredRoutes = append(filteredRoutes, r...)
}
}
return filteredRoutes
}
func getExitNodeRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
for _, route := range routes {
if route.Network == "0.0.0.0/0" {
filteredRoutes = append(filteredRoutes, route)
}
}
return filteredRoutes
}
func sortRoutesByIDs(routes []*proto.Route) {
sort.Slice(routes, func(i, j int) bool {
return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID())
})
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
return nil, fmt.Errorf("get client: %v", err) return nil, fmt.Errorf(getClientFMT, err)
} }
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
@@ -122,8 +223,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
func (s *serviceClient) selectRoute(id string, checked bool) { func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf("get client: %v", err)) s.showError(fmt.Errorf(getClientFMT, err))
return return
} }
@@ -149,16 +250,14 @@ func (s *serviceClient) selectRoute(id string, checked bool) {
} }
} }
func (s *serviceClient) selectAllRoutes() { func (s *serviceClient) selectAllFilteredRoutes(f filter) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) log.Errorf(getClientFMT, err)
return return
} }
req := &proto.SelectRoutesRequest{ req := s.getRoutesRequest(f, true)
All: true,
}
if _, err := conn.SelectRoutes(s.ctx, req); err != nil { if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err) log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err)) s.showError(fmt.Errorf("failed to select all routes: %v", err))
@@ -168,16 +267,14 @@ func (s *serviceClient) selectAllRoutes() {
log.Debug("All routes selected") log.Debug("All routes selected")
} }
func (s *serviceClient) deselectAllRoutes() { func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) log.Errorf(getClientFMT, err)
return return
} }
req := &proto.SelectRoutesRequest{ req := s.getRoutesRequest(f, false)
All: true,
}
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err) log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err)) s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
@@ -187,17 +284,34 @@ func (s *serviceClient) deselectAllRoutes() {
log.Debug("All routes deselected") log.Debug("All routes deselected")
} }
func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest {
req := &proto.SelectRoutesRequest{}
if f == allRoutes {
req.All = true
} else {
routes, err := s.getFilteredRoutes(f)
if err != nil {
return nil
}
for _, route := range routes {
req.RouteIDs = append(req.RouteIDs, route.GetID())
}
req.Append = appendRoute
}
return req
}
func (s *serviceClient) showError(err error) { func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50) wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes) dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
} }
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) { func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
go func() { go func() {
for range ticker.C { for range ticker.C {
s.updateRoutes(grid) s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
} }
}() }()
@@ -206,6 +320,23 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Cont
}) })
} }
func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.wRoutes.Content().Refresh()
s.updateRoutes(grid, f)
}
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
switch tabs.Selected().Text {
case overlappingRoutesText:
return overlappingGrid, overlappingRoutes
case exitNodeRoutesText:
return exitNodesGrid, exitNodeRoutes
default:
return allGrid, allRoutes
}
}
// wrapText inserts newlines into the text to ensure that each line is // wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes. // no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string { func wrapText(text string, lineLength int) string {

View File

@@ -7,6 +7,18 @@ import (
"strings" "strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/context"
)
type ExecutionContext string
const (
ExecutionContextKey = "executionContext"
HTTPSource ExecutionContext = "HTTP"
GRPCSource ExecutionContext = "GRPC"
SystemSource ExecutionContext = "SYSTEM"
) )
// ContextHook is a custom hook for add the source information for the entry // ContextHook is a custom hook for add the source information for the entry
@@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level {
func (hook ContextHook) Fire(entry *logrus.Entry) error { func (hook ContextHook) Fire(entry *logrus.Entry) error {
src := hook.parseSrc(entry.Caller.File) src := hook.parseSrc(entry.Caller.File)
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line) entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
if entry.Context == nil {
return nil
}
source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext)
if !ok {
return nil
}
entry.Data["context"] = source
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
return nil return nil
} }
@@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string {
file := path.Base(filePath) file := path.Base(filePath)
return fmt.Sprintf("%s/%s", pkg, file) return fmt.Sprintf("%s/%s", pkg, file)
} }
func addHTTPFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}

View File

@@ -1,6 +1,8 @@
package formatter package formatter
import "github.com/sirupsen/logrus" import (
"github.com/sirupsen/logrus"
)
// SetTextFormatter set the text formatter for given logger. // SetTextFormatter set the text formatter for given logger.
func SetTextFormatter(logger *logrus.Logger) { func SetTextFormatter(logger *logrus.Logger) {
@@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.AddHook(NewContextHook()) logger.AddHook(NewContextHook())
} }
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {
logger.Formatter = &logrus.JSONFormatter{}
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetLogcatFormatter set the logcat formatter for given logger. // SetLogcatFormatter set the logcat formatter for given logger.
func SetLogcatFormatter(logger *logrus.Logger) { func SetLogcatFormatter(logger *logrus.Logger) {
logger.Formatter = NewLogcatFormatter() logger.Formatter = NewLogcatFormatter()

7
go.mod
View File

@@ -44,7 +44,6 @@ require (
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1 github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
@@ -58,7 +57,7 @@ require (
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -189,8 +188,8 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/image v0.10.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.15.0 // indirect golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect

16
go.sum
View File

@@ -209,8 +209,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
@@ -335,8 +333,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
@@ -542,8 +540,8 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
@@ -565,7 +563,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -655,11 +652,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -28,7 +28,11 @@ services:
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
volumes: volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal # Signal
signal: signal:
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
@@ -40,6 +44,11 @@ services:
# # port and command for Let's Encrypt validation # # port and command for Let's Encrypt validation
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management # Management
management: management:
@@ -63,12 +72,16 @@ services:
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
] ]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn # Coturn
coturn: coturn:
image: coturn/coturn:$COTURN_TAG image: coturn/coturn:$COTURN_TAG
restart: unless-stopped restart: unless-stopped
domainname: $TURN_DOMAIN #domainname: $TURN_DOMAIN # only needed when TLS is enabled
volumes: volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro - ./turnserver.conf:/etc/turnserver.conf:ro
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro # - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
@@ -76,7 +89,11 @@ services:
network_mode: host network_mode: host
command: command:
- -c /etc/turnserver.conf - -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes: volumes:
$MGMT_VOLUMENAME: $MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME: $SIGNAL_VOLUMENAME:

View File

@@ -50,7 +50,7 @@ check_jq() {
wait_crdb() { wait_crdb() {
set +e set +e
while true; do while true; do
if $DOCKER_COMPOSE_COMMAND exec -T crdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then if $DOCKER_COMPOSE_COMMAND exec -T zdb curl -sf -o /dev/null 'http://localhost:8080/health?ready=1'; then
break break
fi fi
echo -n " ." echo -n " ."
@@ -61,14 +61,16 @@ wait_crdb() {
} }
init_crdb() { init_crdb() {
echo -e "\nInitializing Zitadel's CockroachDB\n\n" if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
$DOCKER_COMPOSE_COMMAND up -d crdb echo -e "\nInitializing Zitadel's CockroachDB\n\n"
echo "" $DOCKER_COMPOSE_COMMAND up -d zdb
# shellcheck disable=SC2028 echo ""
echo -n "Waiting cockroachDB to become ready " # shellcheck disable=SC2028
wait_crdb echo -n "Waiting CockroachDB to become ready"
$DOCKER_COMPOSE_COMMAND exec -T crdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/" wait_crdb
handle_request_command_status $? "init_crdb failed" "" $DOCKER_COMPOSE_COMMAND exec -T zdb /bin/bash -c "cp /cockroach/certs/* /zitadel-certs/ && cockroach cert create-client --overwrite --certs-dir /zitadel-certs/ --ca-key /zitadel-certs/ca.key zitadel_user && chown -R 1000:1000 /zitadel-certs/"
handle_request_command_status $? "init_crdb failed" ""
fi
} }
get_main_ip_address() { get_main_ip_address() {
@@ -156,7 +158,7 @@ create_new_application() {
"'"$BASE_REDIRECT_URL2"'" "'"$BASE_REDIRECT_URL2"'"
], ],
"postLogoutRedirectUris": [ "postLogoutRedirectUris": [
"'"$LOGOUT_URL"'" "'"$LOGOUT_URL"'"
], ],
"RESPONSETypes": [ "RESPONSETypes": [
"OIDC_RESPONSE_TYPE_CODE" "OIDC_RESPONSE_TYPE_CODE"
@@ -461,6 +463,20 @@ initEnvironment() {
exit 1 exit 1
fi fi
if [[ $ZITADEL_DATABASE == "cockroach" ]]; then
echo "Use CockroachDB as Zitadel database."
ZDB=$(renderDockerComposeCockroachDB)
ZITADEL_DB_ENV=$(renderZitadelCockroachDBEnv)
else
echo "Use Postgres as default Zitadel database."
echo "For using CockroachDB please the environment variable 'export ZITADEL_DATABASE=cockroach'."
POSTGRES_ROOT_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
POSTGRES_ZITADEL_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')@"
ZDB=$(renderDockerComposePostgres)
ZITADEL_DB_ENV=$(renderZitadelPostgresEnv)
renderPostgresEnv > zdb.env
fi
echo Rendering initial files... echo Rendering initial files...
renderDockerCompose > docker-compose.yml renderDockerCompose > docker-compose.yml
renderCaddyfile > Caddyfile renderCaddyfile > Caddyfile
@@ -474,7 +490,7 @@ initEnvironment() {
init_crdb init_crdb
echo -e "\nStarting Zidatel IDP for user management\n\n" echo -e "\nStarting Zitadel IDP for user management\n\n"
$DOCKER_COMPOSE_COMMAND up -d caddy zitadel $DOCKER_COMPOSE_COMMAND up -d caddy zitadel
init_zitadel init_zitadel
@@ -634,15 +650,15 @@ renderManagementJson() {
"ExtraConfig": { "ExtraConfig": {
"ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1" "ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1"
} }
}, },
"DeviceAuthorizationFlow": { "DeviceAuthorizationFlow": {
"Provider": "hosted", "Provider": "hosted",
"ProviderConfig": { "ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI", "ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid" "Scope": "openid"
} }
}, },
"PKCEAuthorizationFlow": { "PKCEAuthorizationFlow": {
"ProviderConfig": { "ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
@@ -679,16 +695,6 @@ renderZitadelEnv() {
cat <<EOF cat <<EOF
ZITADEL_LOG_LEVEL=debug ZITADEL_LOG_LEVEL=debug
ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
ZITADEL_DATABASE_COCKROACH_HOST=crdb
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/crdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/crdb-certs/client.zitadel_user.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/crdb-certs/client.zitadel_user.key"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/crdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/crdb-certs/client.root.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/crdb-certs/client.root.key"
ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
ZITADEL_TLS_ENABLED="false" ZITADEL_TLS_ENABLED="false"
ZITADEL_EXTERNALPORT=$NETBIRD_PORT ZITADEL_EXTERNALPORT=$NETBIRD_PORT
@@ -698,6 +704,43 @@ ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=zitadel-admin-sa
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=Admin ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=Admin
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_SCOPES=openid ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_SCOPES=openid
ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZIDATE_TOKEN_EXPIRATION_DATE ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZIDATE_TOKEN_EXPIRATION_DATE
$ZITADEL_DB_ENV
EOF
}
renderZitadelCockroachDBEnv() {
cat <<EOF
ZITADEL_DATABASE_COCKROACH_HOST=zdb
ZITADEL_DATABASE_COCKROACH_USER_USERNAME=zitadel_user
ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT="/zdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT="/zdb-certs/client.zitadel_user.crt"
ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY="/zdb-certs/client.zitadel_user.key"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE=verify-full
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT="/zdb-certs/ca.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT="/zdb-certs/client.root.crt"
ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY="/zdb-certs/client.root.key"
EOF
}
renderZitadelPostgresEnv() {
cat <<EOF
ZITADEL_DATABASE_POSTGRES_HOST=zdb
ZITADEL_DATABASE_POSTGRES_PORT=5432
ZITADEL_DATABASE_POSTGRES_DATABASE=zitadel
ZITADEL_DATABASE_POSTGRES_USER_USERNAME=zitadel
ZITADEL_DATABASE_POSTGRES_USER_PASSWORD=$POSTGRES_ZITADEL_PASSWORD
ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE=disable
ZITADEL_DATABASE_POSTGRES_ADMIN_USERNAME=root
ZITADEL_DATABASE_POSTGRES_ADMIN_PASSWORD=$POSTGRES_ROOT_PASSWORD
ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_MODE=disable
EOF
}
renderPostgresEnv() {
cat <<EOF
POSTGRES_USER=root
POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD
EOF EOF
} }
@@ -717,18 +760,28 @@ services:
volumes: volumes:
- netbird_caddy_data:/data - netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile - ./Caddyfile:/etc/caddy/Caddyfile
#UI dashboard # UI dashboard
dashboard: dashboard:
image: netbirdio/dashboard:latest image: netbirdio/dashboard:latest
restart: unless-stopped restart: unless-stopped
networks: [netbird] networks: [netbird]
env_file: env_file:
- ./dashboard.env - ./dashboard.env
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Signal # Signal
signal: signal:
image: netbirdio/signal:latest image: netbirdio/signal:latest
restart: unless-stopped restart: unless-stopped
networks: [netbird] networks: [netbird]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Management # Management
management: management:
image: netbirdio/management:latest image: netbirdio/management:latest
@@ -746,52 +799,49 @@ services:
"--dns-domain=netbird.selfhosted", "--dns-domain=netbird.selfhosted",
"--idp-sign-key-refresh-enabled", "--idp-sign-key-refresh-enabled",
] ]
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Coturn, AKA relay server # Coturn, AKA relay server
coturn: coturn:
image: coturn/coturn image: coturn/coturn
restart: unless-stopped restart: unless-stopped
domainname: netbird.relay.selfhosted #domainname: netbird.relay.selfhosted
volumes: volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro - ./turnserver.conf:/etc/turnserver.conf:ro
network_mode: host network_mode: host
command: command:
- -c /etc/turnserver.conf - -c /etc/turnserver.conf
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
# Zitadel - identity provider # Zitadel - identity provider
zitadel: zitadel:
restart: 'always' restart: 'always'
networks: [netbird] networks: [netbird]
image: 'ghcr.io/zitadel/zitadel:v2.31.3' image: 'ghcr.io/zitadel/zitadel:v2.54.3'
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
env_file: env_file:
- ./zitadel.env - ./zitadel.env
depends_on: depends_on:
crdb: zdb:
condition: 'service_healthy' condition: 'service_healthy'
volumes: volumes:
- ./machinekey:/machinekey - ./machinekey:/machinekey
- netbird_zitadel_certs:/crdb-certs:ro - netbird_zitadel_certs:/zdb-certs:ro
# CockroachDB for zitadel logging:
crdb: driver: "json-file"
restart: 'always' options:
networks: [netbird] max-size: "500m"
image: 'cockroachdb/cockroach:v22.2.2' max-file: "2"
command: 'start-single-node --advertise-addr crdb' $ZDB
volumes: netbird_zdb_data:
- netbird_crdb_data:/cockroach/cockroach-data
- netbird_crdb_certs:/cockroach/certs
- netbird_zitadel_certs:/zitadel-certs
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
interval: '10s'
timeout: '30s'
retries: 5
start_period: '20s'
volumes:
netbird_management: netbird_management:
netbird_caddy_data: netbird_caddy_data:
netbird_crdb_data:
netbird_crdb_certs:
netbird_zitadel_certs: netbird_zitadel_certs:
networks: networks:
@@ -799,4 +849,59 @@ networks:
EOF EOF
} }
renderDockerComposeCockroachDB() {
cat <<EOF
# CockroachDB for Zitadel
zdb:
restart: 'always'
networks: [netbird]
image: 'cockroachdb/cockroach:latest-v23.2'
command: 'start-single-node --advertise-addr zdb'
volumes:
- netbird_zdb_data:/cockroach/cockroach-data
- netbird_zdb_certs:/cockroach/certs
- netbird_zitadel_certs:/zitadel-certs
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:8080/health?ready=1" ]
interval: '10s'
timeout: '30s'
retries: 5
start_period: '20s'
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
netbird_zdb_certs:
EOF
}
renderDockerComposePostgres() {
cat <<EOF
# Postgres for Zitadel
zdb:
restart: 'always'
networks: [netbird]
image: 'postgres:16-alpine'
env_file:
- ./zdb.env
volumes:
- netbird_zdb_data:/var/lib/postgresql/data:rw
healthcheck:
test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"]
interval: 5s
timeout: 60s
retries: 10
start_period: 5s
logging:
driver: "json-file"
options:
max-size: "500m"
max-file: "2"
volumes:
EOF
}
initEnvironment initEnvironment

View File

@@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir) store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -20,6 +20,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -35,8 +36,10 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -77,6 +80,10 @@ var (
Short: "start NetBird Management Server", Short: "start NetBird Management Server",
PreRunE: func(cmd *cobra.Command, args []string) error { PreRunE: func(cmd *cobra.Command, args []string) error {
flag.Parse() flag.Parse()
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
return fmt.Errorf("failed initializing log %v", err) return fmt.Errorf("failed initializing log %v", err)
@@ -85,7 +92,7 @@ var (
// detect whether user specified a port // detect whether user specified a port
userPort := cmd.Flag("port").Changed userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(mgmtConfig) config, err = loadMgmtConfig(ctx, mgmtConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
} }
@@ -116,6 +123,11 @@ var (
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource)
err := handleRebrand(cmd) err := handleRebrand(cmd)
if err != nil { if err != nil {
return fmt.Errorf("failed to migrate files %v", err) return fmt.Errorf("failed to migrate files %v", err)
@@ -131,11 +143,11 @@ var (
if err != nil { if err != nil {
return err return err
} }
err = appMetrics.Expose(mgmtMetricsPort, "/metrics") err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
if err != nil { if err != nil {
return err return err
} }
store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics) store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
} }
@@ -143,7 +155,7 @@ var (
var idpManager idp.Manager var idpManager idp.Manager
if config.IdpManagerConfig != nil { if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics) idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err) return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
} }
@@ -152,32 +164,32 @@ var (
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey) eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize database: %s", err) return fmt.Errorf("failed to initialize database: %s", err)
} }
if config.DataStoreEncryptionKey != key { if config.DataStoreEncryptionKey != key {
log.Infof("update config with activity store key") log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config) err := updateMgmtConfig(ctx, mgmtConfig, config)
if err != nil { if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err) return fmt.Errorf("failed to write out store encryption key: %s", err)
} }
} }
geo, err := geolocation.NewGeolocation(config.Datadir) geo, err := geolocation.NewGeolocation(ctx, config.Datadir)
if err != nil { if err != nil {
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err) log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
} else { } else {
log.Infof("geo location service has been initialized from %s", config.Datadir) log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir)
} }
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore) integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err) return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
@@ -188,13 +200,13 @@ var (
trustedPeers := config.ReverseProxy.TrustedPeers trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) { if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers trustedPeers = defaultTrustedPeers
} }
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 { if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
} }
realipOpts := []realip.Option{ realipOpts := []realip.Option{
@@ -206,8 +218,8 @@ var (
gRPCOpts := []grpc.ServerOption{ gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp), grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)), grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)), grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
} }
var certManager *autocert.Manager var certManager *autocert.Manager
@@ -224,7 +236,7 @@ var (
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" { } else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey) tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil { if err != nil {
log.Errorf("cannot load TLS credentials: %v", err) log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err return err
} }
transportCredentials := credentials.NewTLS(tlsConfig) transportCredentials := credentials.NewTLS(tlsConfig)
@@ -233,6 +245,7 @@ var (
} }
jwtValidator, err := jwtclaims.NewJWTValidator( jwtValidator, err := jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation, config.HttpConfig.AuthKeysLocation,
@@ -249,26 +262,24 @@ var (
KeysLocation: config.HttpConfig.AuthKeysLocation, KeysLocation: config.HttpConfig.AuthKeysLocation,
} }
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }
ephemeralManager := server.NewEphemeralManager(store, accountManager) ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers() ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager) srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
if err != nil { if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err) return fmt.Errorf("failed creating gRPC API handler: %v", err)
} }
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(store) installationID, err := getInstallationID(ctx, store)
if err != nil { if err != nil {
log.Errorf("cannot load TLS credentials: %v", err) log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err return err
} }
@@ -278,18 +289,18 @@ var (
idpManager = config.IdpManagerConfig.ManagerType idpManager = config.IdpManagerConfig.ManagerType
} }
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager) metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run() go metricsWorker.Run(ctx)
} }
var compatListener net.Listener var compatListener net.Listener
if mgmtPort != ManagementLegacyPort { if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it // The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073. // are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort) compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
if err != nil { if err != nil {
return err return err
} }
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
} }
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler) rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
@@ -306,8 +317,8 @@ var (
if err != nil { if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err) return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
} }
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String()) log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(cml, certManager.HTTPHandler(nil)) serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
} }
} else if tlsConfig != nil { } else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig) listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
@@ -321,14 +332,14 @@ var (
} }
} }
log.Infof("management server version %s", version.NetbirdVersion()) log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(listener, rootHandler, tlsEnabled) serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh
integratedPeerValidator.Stop() integratedPeerValidator.Stop(ctx)
if geo != nil { if geo != nil {
_ = geo.Stop() _ = geo.Stop()
} }
@@ -339,39 +350,68 @@ var (
_ = certManager.Listener().Close() _ = certManager.Listener().Close()
} }
gRPCAPIHandler.Stop() gRPCAPIHandler.Stop()
_ = store.Close() _ = store.Close(ctx)
_ = eventStore.Close() _ = eventStore.Close(ctx)
log.Infof("stopped Management Service") log.WithContext(ctx).Infof("stopped Management Service")
return nil return nil
}, },
} }
) )
func notifyStop(msg string) { func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
select { select {
case stopCh <- 1: case stopCh <- 1:
log.Error(msg) log.WithContext(ctx).Error(msg)
default: default:
// stop has been already called, nothing to report // stop has been already called, nothing to report
} }
} }
func getInstallationID(store server.Store) (string, error) { func getInstallationID(ctx context.Context, store server.Store) (string, error) {
installationID := store.GetInstallationID() installationID := store.GetInstallationID()
if installationID != "" { if installationID != "" {
return installationID, nil return installationID, nil
} }
installationID = strings.ToUpper(uuid.New().String()) installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(installationID) err := store.SaveInstallationID(ctx, installationID)
if err != nil { if err != nil {
return "", err return "", err
} }
return installationID, nil return installationID, nil
} }
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
go func() { go func() {
err := grpcServer.Serve(listener) err := grpcServer.Serve(listener)
if err != nil { if err != nil {
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err)) notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
} }
}() }()
return listener, nil return listener, nil
} }
func serveHTTP(httpListener net.Listener, handler http.Handler) { func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
go func() { go func() {
err := http.Serve(httpListener, handler) err := http.Serve(httpListener, handler)
if err != nil { if err != nil {
notifyStop(fmt.Sprintf("failed running HTTP server: %v", err)) notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
} }
}() }()
} }
func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) { func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() { go func() {
var err error var err error
if tlsEnabled { if tlsEnabled {
@@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b
if err != nil { if err != nil {
select { select {
case stopCh <- 1: case stopCh <- 1:
log.Errorf("failed to serve HTTP and gRPC server: %v", err) log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
default: default:
// stop has been already called, nothing to report // stop has been already called, nothing to report
} }
@@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
}) })
} }
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{} loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, loadedConfig) _, err := util.ReadJson(mgmtConfigPath, loadedConfig)
if err != nil { if err != nil {
@@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint != "" { if oidcEndpoint != "" {
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically // if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint) log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(oidcEndpoint) oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer) oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) { if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint) oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
@@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain) u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
@@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
} }
if loadedConfig.PKCEAuthorizationFlow != nil { if loadedConfig.PKCEAuthorizationFlow != nil {
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint) oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
} }
@@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err return loadedConfig, err
} }
func updateMgmtConfig(path string, config *server.Config) error { func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
return util.DirectWriteJson(path, config) return util.DirectWriteJson(ctx, path, config)
} }
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response
@@ -515,7 +555,7 @@ type OIDCConfigResponse struct {
} }
// fetchOIDCConfig fetches OIDC configuration from the IDP // fetchOIDCConfig fetches OIDC configuration from the IDP
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint) res, err := http.Get(oidcEndpoint)
if err != nil { if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err) return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
@@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
defer func() { defer func() {
err := res.Body.Close() err := res.Body.Close()
if err != nil { if err != nil {
log.Debugf("failed closing response body %v", err) log.WithContext(ctx).Debugf("failed closing response body %v", err)
} }
}() }()

View File

@@ -1,13 +1,16 @@
package cmd package cmd
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
) )
var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command." var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command."
@@ -26,10 +29,13 @@ var upCmd = &cobra.Command{
return fmt.Errorf("failed initializing log %v", err) return fmt.Errorf("failed initializing log %v", err)
} }
if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil { //nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
return err return err
} }
log.Info("Migration finished successfully") log.WithContext(ctx).Info("Migration finished successfully")
return nil return nil
}, },

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"encoding/json" "encoding/json"
@@ -29,11 +30,11 @@ import (
type MocIntegratedValidator struct { type MocIntegratedValidator struct {
} }
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil return nil
} }
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil return update, nil
} }
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
@@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
return validatedPeers, nil return validatedPeers, nil
} }
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer return peer
} }
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil return false, false, nil
} }
func (MocIntegratedValidator) PeerDeleted(_, _ string) error { func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil return nil
} }
@@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
} }
func (MocIntegratedValidator) Stop() { func (MocIntegratedValidator) Stop(_ context.Context) {
} }
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
@@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
setupKey = key.Key setupKey = key.Key
} }
_, _, err := manager.AddPeer(setupKey, userID, peer) _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
if err != nil { if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err) t.Error("expected to add new peer successfully after creating new account, but failed", err)
} }
@@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
for _, testCase := range tt { for _, testCase := range tt {
account := newAccountWithId("account-1", userID, "netbird.io") account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
account.UpdateSettings(&testCase.accountSettings) account.UpdateSettings(&testCase.accountSettings)
account.Network = network account.Network = network
account.Peers = testCase.peers account.Peers = testCase.peers
@@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{} validatedPeers[p] = struct{}{}
} }
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }
@@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io" domain := "netbird.io"
userId := "account_creator" userId := "account_creator"
accountID := "account_id" accountID := "account_id"
account := newAccountWithId(accountID, userId, domain) account := newAccountWithId(context.Background(), accountID, userId, domain)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
} }
@@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(userID, "") account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return return
} }
account, err = manager.Store.GetAccountByUser(userID) account, err = manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return return
@@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed") require.NoError(t, err, "update init user failed")
} }
@@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, _, err := manager.GetAccountFromToken(testCase.inputClaims) account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
initAccount := newAccountWithId("", userId, domain) initAccount := newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain) acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization // as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated // that happens inside the GetAccountByUserOrAccountID where the id is getting generated
@@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
} }
t.Run("JWT groups disabled", func(t *testing.T) { t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(claims) account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, account.Groups, 1, "only ALL group should exists")
}) })
t.Run("JWT groups enabled without claim name", func(t *testing.T) { t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(initAccount) err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims) account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
}) })
@@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
t.Run("JWT groups enabled", func(t *testing.T) { t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups" initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(initAccount) err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims) account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account") require.Len(t, account.Groups, 3, "groups should be added to the account")
@@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
@@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
}, },
}, },
} }
err := store.SaveAccount(account) err := store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
t.Fatalf("Error when saving account: %s", err) t.Fatalf("Error when saving account: %s", err)
} }
@@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store, Store: store,
} }
account, user, pat, err := am.GetAccountFromPAT(token) account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
if err != nil { if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err) t.Fatalf("Error when getting Account from PAT: %s", err)
} }
@@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
@@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
}, },
}, },
} }
err := store.SaveAccount(account) err := store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
t.Fatalf("Error when saving account: %s", err) t.Fatalf("Error when saving account: %s", err)
} }
@@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
Store: store, Store: store,
} }
err = am.MarkPATUsed("tokenId") err = am.MarkPATUsed(context.Background(), "tokenId")
if err != nil { if err != nil {
t.Fatalf("Error when marking PAT used: %s", err) t.Fatalf("Error when marking PAT used: %s", err)
} }
account, err = am.Store.GetAccount("account_id") account, err = am.Store.GetAccount(context.Background(), "account_id")
if err != nil { if err != nil {
t.Fatalf("Error when getting account: %s", err) t.Fatalf("Error when getting account: %s", err)
} }
@@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
} }
userId := "test_user" userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(userId, "") account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.Store.GetAccountByUser(userId) account, err = manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user" userId := "test_user"
domain := "hotmail.com" domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(userId, domain) account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
domain = "gmail.com" domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(userId, domain) account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil { if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err) t.Fatalf("got the following error while retrieving existing acc: %v", err)
} }
@@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(userId, "", "") account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return return
} }
_, err = manager.GetAccountByUserOrAccountID("", account.Id, "") _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
} }
_, err = manager.GetAccountByUserOrAccountID("", "", "") _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
} }
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(accountID, userID, domain) account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
} }
// AddAccount has been already tested so we can assume it is correct and compare results // AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.Store.GetAccount(account.Id) getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.DeleteAccount(account.Id, userId) err = manager.DeleteAccount(context.Background(), account.Id, userId)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
getAccount, err := manager.Store.GetAccount(account.Id) getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
if err == nil { if err == nil {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
} }
@@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0 serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key expectedSetupKey := setupKey.Key
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
@@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return return
} }
account, err = manager.Store.GetAccount(account.Id) account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud") account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedUserID := userID expectedUserID := userID
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
@@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err = manager.Store.GetAccount(account.Id) account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
@@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer() peer2 := getPeer()
peer3 := getPeer() peer3 := getPeer()
account, err = manager.Store.GetAccount(account.Id) account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ group := group.Group{
ID: "group-id", ID: "group-id",
@@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.SaveGroup(account.Id, userID, &group); err != nil { if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err) t.Errorf("save group: %v", err)
return return
} }
@@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil { if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.SavePolicy(account.Id, userID, &policy); err != nil { if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil { if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err) t.Errorf("delete peer: %v", err)
return return
} }
@@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}() }()
// clean policy is pre requirement for delete group // clean policy is pre requirement for delete group
_ = manager.DeletePolicy(account.Id, policy.ID, userID) _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String() peerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey, Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
}) })
@@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
err = manager.DeletePeer(account.Id, peerKey, userID) err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
if err != nil { if err != nil {
return return
} }
account, err = manager.Store.GetAccount(account.Id) account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatal("no PeerAddedWithSetupKey event was generated") t.Fatal("no PeerAddedWithSetupKey event was generated")
default: default:
events, err := manager.GetEvents(accountID, userID) events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user account.Users[user.Id] = user
} }
userInfos, err := manager.GetUsersFromAccount(accountId, "1") userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
}, },
} }
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2) routeIDs := make(map[route.ID]struct{}, 2)
@@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3")) assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
assert.Len(t, emptyRoutes, 0) assert.Len(t, emptyRoutes, 0)
} }
@@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings) assert.NotNil(t, account.Settings)
@@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "") _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key") require.NoError(t, err, "unable to generate WireGuard key")
peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(), Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) { CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done() wg.Done()
}, },
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done() wg.Done()
}, },
} }
@@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first // disable expiration first
update := peer.Copy() update := peer.Copy()
update.LoginExpirationEnabled = false update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine // enabling expiration should trigger the routine
update.LoginExpirationEnabled = true update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key") require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(), Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) { CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done() wg.Done()
}, },
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done() wg.Done()
}, },
} }
account, err = manager.GetAccountByUserOrAccountID(userID, "", "") account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger // when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "") _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key") require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(), Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) { CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done() wg.Done()
}, },
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done() wg.Done()
}, },
} }
// enabling PeerLoginExpirationEnabled should trigger the expiration job // enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1) wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel // disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
@@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{ updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
@@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID("", account.Id, "") account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
require.NoError(t, err, "unable to get account by ID") require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled) assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Second, PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
@@ -2294,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2305,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
func createStore(t *testing.T) (Store, error) { func createStore(t *testing.T) (Store, error) {
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir) store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package sqlite package sqlite
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -86,7 +87,7 @@ type Store struct {
} }
// NewSQLiteStore creates a new Store with an event table if not exists. // NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB) dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile) db, err := sql.Open("sqlite3", dbFile)
if err != nil { if err != nil {
@@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return nil, err return nil, err
} }
err = updateDeletedUsersTable(db) err = updateDeletedUsersTable(ctx, db)
if err != nil { if err != nil {
_ = db.Close() _ = db.Close()
return nil, err return nil, err
@@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return s, nil return s, nil
} }
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0) events := make([]*activity.Event, 0)
var cryptErr error var cryptErr error
for result.Next() { for result.Next() {
@@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
} }
if cryptErr != nil { if cryptErr != nil {
log.Warnf("%s", cryptErr) log.WithContext(ctx).Warnf("%s", cryptErr)
} }
return events, nil return events, nil
} }
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp // Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
stmt := store.selectDescStatement stmt := store.selectDescStatement
if !descending { if !descending {
stmt = store.selectAscStatement stmt = store.selectAscStatement
@@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
} }
defer result.Close() //nolint defer result.Close() //nolint
return store.processResult(result) return store.processResult(ctx, result)
} }
// Save an event in the SQLite events table end encrypt the "email" element in meta map // Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) { func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
var jsonMeta string var jsonMeta string
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event) meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
if err != nil { if err != nil {
@@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
} }
// Close the Store // Close the Store
func (store *Store) Close() error { func (store *Store) Close(_ context.Context) error {
if store.db != nil { if store.db != nil {
return store.db.Close() return store.db.Close()
} }
return nil return nil
} }
func updateDeletedUsersTable(db *sql.DB) error { func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
log.Debugf("check deleted_users table version") log.WithContext(ctx).Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`) rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
if err != nil { if err != nil {
return err return err
@@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error {
return nil return nil
} }
log.Debugf("update delted_users table") log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err return err
} }

View File

@@ -1,6 +1,7 @@
package sqlite package sqlite
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@@ -13,17 +14,17 @@ import (
func TestNewSQLiteStore(t *testing.T) { func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir() dataDir := t.TempDir()
key, _ := GenerateKey() key, _ := GenerateKey()
store, err := NewSQLiteStore(dataDir, key) store, err := NewSQLiteStore(context.Background(), dataDir, key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
} }
defer store.Close() //nolint defer store.Close(context.Background()) //nolint
accountID := "account_1" accountID := "account_1"
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, err = store.Save(&activity.Event{ _, err = store.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser, Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i), InitiatorID: "user_" + fmt.Sprint(i),
@@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) {
} }
} }
result, err := store.Get(accountID, 0, 10, false) result, err := store.Get(context.Background(), accountID, 0, 10, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) {
assert.Len(t, result, 10) assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp)) assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
result, err = store.Get(accountID, 0, 5, true) result, err = store.Get(context.Background(), accountID, 0, 5, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -1,15 +1,18 @@
package activity package activity
import "sync" import (
"context"
"sync"
)
// Store provides an interface to store or stream events. // Store provides an interface to store or stream events.
type Store interface { type Store interface {
// Save an event in the store // Save an event in the store
Save(event *Event) (*Event, error) Save(ctx context.Context, event *Event) (*Event, error)
// Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp // Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp
Get(accountID string, offset, limit int, descending bool) ([]*Event, error) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error)
// Close the sink flushing events if necessary // Close the sink flushing events if necessary
Close() error Close(ctx context.Context) error
} }
// InMemoryEventStore implements the Store interface storing data in-memory // InMemoryEventStore implements the Store interface storing data in-memory
@@ -20,7 +23,7 @@ type InMemoryEventStore struct {
} }
// Save sets the Event.ID to 1 // Save sets the Event.ID to 1
func (store *InMemoryEventStore) Save(event *Event) (*Event, error) { func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) {
store.mu.Lock() store.mu.Lock()
defer store.mu.Unlock() defer store.mu.Unlock()
if store.events == nil { if store.events == nil {
@@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
} }
// Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration // Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration
func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) { func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) {
store.mu.Lock() store.mu.Lock()
defer store.mu.Unlock() defer store.mu.Unlock()
events := make([]*Event, 0) events := make([]*Event, 0)
@@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen
} }
// Close cleans up the event list // Close cleans up the event list
func (store *InMemoryEventStore) Close() error { func (store *InMemoryEventStore) Close(_ context.Context) error {
store.mu.Lock() store.mu.Lock()
defer store.mu.Unlock() defer store.mu.Unlock()
store.events = make([]*Event, 0) store.events = make([]*Event, 0)

View File

@@ -0,0 +1,8 @@
package context
const (
RequestIDKey = "requestID"
AccountIDKey = "accountID"
UserIDKey = "userID"
PeerIDKey = "peerID"
)

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
@@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings {
} }
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
account.DNSSettings = dnsSettingsToSave.Copy() account.DNSSettings = dnsSettingsToSave.Copy()
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
@@ -97,17 +98,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
for _, id := range addedGroups { for _, id := range addedGroups {
group := account.GetGroup(id) group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID} meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
} }
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups { for _, id := range removedGroups {
group := account.GetGroup(id) group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID} meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
} }
am.updateAccountPeers(account) // todo: check if before/after groups are in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
return nil return nil
} }
@@ -149,9 +151,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
return protoUpdate return protoUpdate
} }
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone { func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" { if dnsDomain == "" {
log.Errorf("no dns domain is set, returning empty zone") log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{} return nbdns.CustomZone{}
} }
@@ -161,7 +163,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
for _, peer := range account.Peers { for _, peer := range account.Peers {
if peer.DNSLabel == "" { if peer.DNSLabel == "" {
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name) log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue continue
} }
@@ -210,14 +212,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
return false return false
} }
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) { func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers { for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels) label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil { if err != nil {
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil { if err != nil {
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
continue continue
} }
} }

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
@@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("failed to init testing account") t.Fatal("failed to init testing account")
} }
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID) dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil { if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) {
DisabledManagementGroups: []string{group1ID}, DisabledManagementGroups: []string{group1ID},
} }
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
t.Error("failed to save testing account with new DNS settings") t.Error("failed to save testing account with new DNS settings")
} }
dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID) dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil { if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
} }
_, err = am.GetDNSSettings(account.Id, dnsRegularUserID) _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil { if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user") t.Errorf("An error should be returned when getting the DNS settings with a regular user")
} }
@@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings) err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
if err != nil { if err != nil {
if testCase.shouldFail { if testCase.shouldFail {
return return
@@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err) t.Error(err)
} }
updatedAccount, err := am.Store.GetAccount(account.Id) updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err) t.Errorf("should be able to retrieve updated account, got err: %s", err)
} }
@@ -164,7 +165,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers") require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
@@ -173,14 +174,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
dnsSettings := account.DNSSettings.Copy() dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings account.DNSSettings = dnsSettings
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group") require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group") require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID) peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group") require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group") require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
@@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
} }
func createDNSStore(t *testing.T) (Store, error) { func createDNSStore(t *testing.T) (Store, error) {
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir) store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com" domain := "example.com"
account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain) account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account.Users[dnsRegularUserID] = &User{ account.Users[dnsRegularUserID] = &User{
Id: dnsRegularUserID, Id: dnsRegularUserID,
Role: UserRoleUser, Role: UserRoleUser,
} }
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1) savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, _, err = am.AddPeer("", dnsAdminUserID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account, err = am.Store.GetAccount(account.Id) account, err = am.Store.GetAccount(context.Background(), account.Id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Groups: []string{allGroup.ID}, Groups: []string{allGroup.ID},
} }
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccount(account.Id) return am.Store.GetAccount(context.Background(), account.Id)
} }

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"sync" "sync"
"time" "time"
@@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head // LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new // of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head. // head.
func (e *EphemeralManager) LoadInitialPeers() { func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.peersLock.Lock() e.peersLock.Lock()
defer e.peersLock.Unlock() defer e.peersLock.Unlock()
e.loadEphemeralPeers() e.loadEphemeralPeers(ctx)
if e.headPeer != nil { if e.headPeer != nil {
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup) e.timer = time.AfterFunc(ephemeralLifeTime, func() {
e.cleanup(ctx)
})
} }
} }
@@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() {
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer // OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active. // is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) { func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral { if !peer.Ephemeral {
return return
} }
log.Tracef("remove peer from ephemeral list: %s", peer.ID) log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
e.peersLock.Lock() e.peersLock.Lock()
defer e.peersLock.Unlock() defer e.peersLock.Unlock()
@@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period. // is inactive it will be deleted after the ephemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) { func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral { if !peer.Ephemeral {
return return
} }
log.Tracef("add peer to ephemeral list: %s", peer.ID) log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(peer.ID) a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil { if err != nil {
log.Errorf("failed to add peer to ephemeral list: %s", err) log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return return
} }
@@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
e.addPeer(peer.ID, a, newDeadLine()) e.addPeer(peer.ID, a, newDeadLine())
if e.timer == nil { if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
} }
} }
func (e *EphemeralManager) loadEphemeralPeers() { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts() accounts := e.store.GetAllAccounts(context.Background())
t := newDeadLine() t := newDeadLine()
count := 0 count := 0
for _, a := range accounts { for _, a := range accounts {
@@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() {
} }
} }
} }
log.Debugf("loaded ephemeral peer(s): %d", count) log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
} }
func (e *EphemeralManager) cleanup() { func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup") log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer) deletePeers := make(map[string]*ephemeralPeer)
@@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() {
} }
if e.headPeer != nil { if e.headPeer != nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
} else { } else {
e.timer = nil e.timer = nil
} }
@@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() {
e.peersLock.Unlock() e.peersLock.Unlock()
for id, p := range deletePeers { for id, p := range deletePeers {
log.Debugf("delete ephemeral peer: %s", id) log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.Errorf("failed to delete ephemeral peer: %s", err) log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} }
} }
} }

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@@ -13,11 +14,11 @@ type MockStore struct {
account *Account account *Account
} }
func (s *MockStore) GetAllAccounts() []*Account { func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account} return []*Account{s.account}
} }
func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) { func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId] _, ok := s.account.Peers[peerId]
if ok { if ok {
return s.account, nil return s.account, nil
@@ -31,7 +32,7 @@ type MocAccountManager struct {
store *MockStore store *MockStore
} }
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error { func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
delete(a.store.account.Peers, peerID) delete(a.store.account.Peers, peerID)
return nil //nolint:nil return nil //nolint:nil
} }
@@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers() mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup() mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers { if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
@@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers() mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"]) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup() mgr.cleanup(context.Background())
expected := numberOfPeers + 1 expected := numberOfPeers + 1
if len(store.account.Peers) != expected { if len(store.account.Peers) != expected {
@@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers() mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers { for _, v := range store.account.Peers {
mgr.OnPeerConnected(v) mgr.OnPeerConnected(context.Background(), v)
} }
mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"]) mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup() mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1 expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected { if len(store.account.Peers) != expected {
@@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
} }
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId("my account", "", "") store.account = newAccountWithId(context.Background(), "my account", "", "")
for i := 0; i < numberOfPeers; i++ { for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i) peerId := fmt.Sprintf("peer_%d", i)

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@@ -11,11 +12,11 @@ import (
) )
// GetEvents returns a list of activity events of an account // GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
} }
events, err := am.eventStore.Get(accountID, 0, 10000, true) events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return filtered, nil return filtered, nil
} }
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
go func() { go func() {
_, err := am.eventStore.Save(&activity.Event{ _, err := am.eventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
Activity: activityID, Activity: activityID,
InitiatorID: initiatorID, InitiatorID: initiatorID,
@@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str
}) })
if err != nil { if err != nil {
// todo add metric // todo add metric
log.Errorf("received an error while storing an activity event, error: %s", err) log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err)
} }
}() }()

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
accountID string, count int) { accountID string, count int) {
t.Helper() t.Helper()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
_, err := manager.eventStore.Save(&activity.Event{ _, err := manager.eventStore.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
Activity: typ, Activity: typ,
InitiatorID: initiatorID, InitiatorID: initiatorID,
@@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
accountID := "accountID" accountID := "accountID"
t.Run("get empty events list", func(t *testing.T) { t.Run("get empty events list", func(t *testing.T) {
events, err := manager.GetEvents(accountID, userID) events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil { if err != nil {
return return
} }
assert.Len(t, events, 0) assert.Len(t, events, 0)
_ = manager.eventStore.Close() //nolint _ = manager.eventStore.Close(context.Background()) //nolint
}) })
t.Run("get events", func(t *testing.T) { t.Run("get events", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10) generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
events, err := manager.GetEvents(accountID, userID) events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil { if err != nil {
return return
} }
assert.Len(t, events, 10) assert.Len(t, events, 10)
_ = manager.eventStore.Close() //nolint _ = manager.eventStore.Close(context.Background()) //nolint
}) })
t.Run("get events without duplicates", func(t *testing.T) { t.Run("get events without duplicates", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10) generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
events, err := manager.GetEvents(accountID, userID) events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil { if err != nil {
return return
} }
assert.Len(t, events, 1) assert.Len(t, events, 1)
_ = manager.eventStore.Close() //nolint _ = manager.eventStore.Close(context.Background()) //nolint
}) })
} }

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -48,8 +49,8 @@ type FileStore struct {
type StoredAccount struct{} type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir // NewFileStore restores a store from the file located in the datadir
func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(filepath.Join(dataDir, storeFileName)) fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
} }
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir // NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(dataDir, metrics) store, err := NewFileStore(ctx, dataDir, metrics)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = store.SaveInstallationID(sqlStore.GetInstallationID()) err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, account := range sqlStore.GetAllAccounts() { for _, account := range sqlStore.GetAllAccounts(ctx) {
store.Accounts[account.Id] = account store.Accounts[account.Id] = account
} }
return store, store.persist(store.storeFile) return store, store.persist(ctx, store.storeFile)
} }
// restore the state of the store from the file. // restore the state of the store from the file.
// Creates a new empty store file if doesn't exist // Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) { func restore(ctx context.Context, file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) { if _, err := os.Stat(file); os.IsNotExist(err) {
// create a new FileStore if previously didn't exist (e.g. first run) // create a new FileStore if previously didn't exist (e.g. first run)
s := &FileStore{ s := &FileStore{
@@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) {
storeFile: file, storeFile: file,
} }
err = s.persist(file) err = s.persist(ctx, file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) {
// for data migration. Can be removed once most base will be with labels // for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels() existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) { if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(account, existingLabels) addPeerLabelsToAccount(ctx, account, existingLabels)
} }
// TODO: delete this block after migration // TODO: delete this block after migration
@@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) {
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
// if the All group didn't exist we probably don't have routes to update // if the All group didn't exist we probably don't have routes to update
continue continue
} }
@@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) {
} }
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected) // we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(store.storeFile) err = store.persist(ctx, store.storeFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) {
// persist account data to a file // persist account data to a file
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(file string) error { func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now() start := time.Now()
err := util.WriteJson(file, s) err := util.WriteJson(file, s)
if err != nil { if err != nil {
@@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error {
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took) s.metrics.StoreMetrics().CountPersistenceDuration(took)
} }
log.Debugf("took %d ms to persist the FileStore", took.Milliseconds()) log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds())
return nil return nil
} }
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock() (unlock func()) { func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.Debugf("acquiring global lock") log.WithContext(ctx).Debugf("acquiring global lock")
start := time.Now() start := time.Now()
s.globalAccountLock.Lock() s.globalAccountLock.Lock()
unlock = func() { unlock = func() {
s.globalAccountLock.Unlock() s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start)) log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
} }
took := time.Since(start) took := time.Since(start)
log.Debugf("took %v to acquire global lock", took) log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
} }
@@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
} }
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock // AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID) log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex) mtx := value.(*sync.Mutex)
@@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
} }
return unlock return unlock
@@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock // AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks // This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) { func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID) return s.AcquireAccountWriteLock(ctx, accountID)
} }
func (s *FileStore) SaveAccount(account *Account) error { func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
} }
return s.persist(s.storeFile) return s.persist(ctx, s.storeFile)
} }
func (s *FileStore) DeleteAccount(account *Account) error { func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error {
delete(s.Accounts, account.Id) delete(s.Accounts, account.Id)
return s.persist(s.storeFile) return s.persist(ctx, s.storeFile)
} }
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID // DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
@@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
// GetAccountByPrivateDomain returns account by private domain // GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
} }
// GetAccountBySetupKey returns account by setup key id // GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
} }
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret // GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
} }
// GetUserByTokenID returns a User object a tokenID belongs to // GetUserByTokenID returns a User object a tokenID belongs to
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
} }
// GetAllAccounts returns all accounts // GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts() (all []*Account) { func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
for _, a := range s.Accounts { for _, a := range s.Accounts {
@@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) {
} }
// GetAccount returns an account for ID // GetAccount returns an account for ID
func (s *FileStore) GetAccount(accountID string) (*Account, error) { func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) {
} }
// GetAccountByUser returns a user account // GetAccountByUser returns a user account
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
} }
// GetAccountByPeerID returns an account for a given peer ID // GetAccountByPeerID returns an account for a given peer ID
func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
// check Account.Peers for a match // check Account.Peers for a match
if _, ok := account.Peers[peerID]; !ok { if _, ok := account.Peers[peerID]; !ok {
delete(s.PeerID2AccountID, peerID) delete(s.PeerID2AccountID, peerID)
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
return nil, status.NewPeerNotFoundError(peerID) return nil, status.NewPeerNotFoundError(peerID)
} }
@@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
} }
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key // GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
} }
if stale { if stale {
delete(s.PeerKeyID2AccountID, peerKey) delete(s.PeerKeyID2AccountID, peerKey)
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
return nil, status.NewPeerNotFoundError(peerKey) return nil, status.NewPeerNotFoundError(peerKey)
} }
return account.Copy(), nil return account.Copy(), nil
} }
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
return accountID, nil return accountID, nil
} }
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) { func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
return accountID, nil return accountID, nil
} }
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
return nil, status.NewPeerNotFoundError(peerKey) return nil, status.NewPeerNotFoundError(peerKey)
} }
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) { func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string {
} }
// SaveInstallationID saves the installation ID // SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(ID string) error { func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.InstallationID = ID s.InstallationID = ID
return s.persist(s.storeFile) return s.persist(ctx, s.storeFile)
} }
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
@@ -732,13 +733,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *
} }
// Close the FileStore persisting data to disk // Close the FileStore persisting data to disk
func (s *FileStore) Close() error { func (s *FileStore) Close(ctx context.Context) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
log.Infof("closing FileStore") log.WithContext(ctx).Infof("closing FileStore")
return s.persist(s.storeFile) return s.persist(ctx, s.storeFile)
} }
// GetStoreEngine returns FileStoreEngine // GetStoreEngine returns FileStoreEngine

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"net" "net"
"path/filepath" "path/filepath"
@@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
return return
} }
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err) require.NoError(t, err)
peerID := "some_peer" peerID := "some_peer"
@@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) {
Key: peerKey, Key: peerKey,
} }
err = store.SaveAccount(account) err = store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
account.DeletePeer(peerID) account.DeletePeer(peerID)
err = store.SaveAccount(account) err = store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
_, err = store.GetAccountByPeerID(peerID) _, err = store.GetAccountByPeerID(context.Background(), peerID)
require.Error(t, err, "expecting to get an error when found stale index") require.Error(t, err, "expecting to get an error when found stale index")
_, err = store.GetAccountByPeerPubKey(peerKey) _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
require.Error(t, err, "expecting to get an error when found stale index") require.Error(t, err, "expecting to get an error when found stale index")
} }
func TestNewStore(t *testing.T) { func TestNewStore(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close() defer store.Close(context.Background())
if store.Accounts == nil || len(store.Accounts) != 0 { if store.Accounts == nil || len(store.Accounts) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) { func TestSaveAccount(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close() defer store.Close(context.Background())
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
@@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) {
} }
// SaveAccount should trigger persist // SaveAccount should trigger persist
err := store.SaveAccount(account) err := store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return return
} }
@@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer store.Close() defer store.Close(context.Background())
var account *Account var account *Account
for _, a := range store.Accounts { for _, a := range store.Accounts {
@@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) {
require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
err = store.DeleteAccount(account) err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err, "failed to delete account, error: %v", err) require.NoError(t, err, "failed to delete account, error: %v", err)
_, ok := store.Accounts[account.Id] _, ok := store.Accounts[account.Id]
@@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) {
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close() defer store.Close(context.Background())
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account.Peers["testpeer"] = &nbpeer.Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
@@ -228,12 +229,12 @@ func TestStore(t *testing.T) {
}) })
// SaveAccount should trigger persist // SaveAccount should trigger persist
err := store.SaveAccount(account) err := store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return return
} }
restored, err := NewFileStore(store.storeFile, nil) restored, err := NewFileStore(context.Background(), store.storeFile, nil)
if err != nil { if err != nil {
return return
} }
@@ -281,7 +282,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
return return
} }
@@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
return return
} }
@@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) {
Name: "All", Name: "All",
}, },
} }
err = store.SaveAccount(account) err = store.SaveAccount(context.Background(), account)
require.NoError(t, err, "failed to save account") require.NoError(t, err, "failed to save account")
// restore account with default group with empty Issue field // restore account with default group with empty Issue field
if store, err = NewFileStore(storeDir, nil); err != nil { if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
return return
} }
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
@@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
return return
} }
existingDomain := "test.com" existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(existingDomain) account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
require.NoError(t, err, "should found account") require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match") require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain("missing-domain.com") _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
require.Error(t, err, "should return error on domain lookup") require.Error(t, err, "should return error on domain lookup")
} }
@@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) {
return return
} }
account, err := store.GetAccount(expected.Id) account, err := store.GetAccount(context.Background(), expected.Id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
tokenID, err := store.GetTokenIDByHashedToken(hashedToken) tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close() defer store.Close(context.Background())
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
@@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:])) _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
} }
@@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
user, err := store.GetUserByTokenID(tokenID) user, err := store.GetUserByTokenID(context.Background(), tokenID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wrongTokenID := "someNonExistingTokenID" wrongTokenID := "someNonExistingTokenID"
_, err = store.GetUserByTokenID(wrongTokenID) _, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
} }
@@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
return return
} }
@@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
} }
err = store.SaveAccount(account) err = store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
store, err := NewFileStore(storeDir, nil) store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil { if err != nil {
return return
} }
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err) require.NoError(t, err)
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
@@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
account.Peers[peer.ID] = peer account.Peers[peer.ID] = peer
err = store.SaveAccount(account) err = store.SaveAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
@@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
assert.NoError(t, err) assert.NoError(t, err)
account, err = store.GetAccount(account.Id) account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err) require.NoError(t, err)
actual := account.Peers[peer.ID].Location actual := account.Peers[peer.ID].Location
@@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
func newStore(t *testing.T) *FileStore { func newStore(t *testing.T) *FileStore {
t.Helper() t.Helper()
store, err := NewFileStore(t.TempDir(), nil) store, err := NewFileStore(context.Background(), t.TempDir(), nil)
if err != nil { if err != nil {
t.Errorf("failed creating a new store") t.Errorf("failed creating a new store")
} }

View File

@@ -2,6 +2,7 @@ package geolocation
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net" "net"
"os" "os"
@@ -52,7 +53,7 @@ type Country struct {
CountryName string CountryName string
} }
func NewGeolocation(dataDir string) (*Geolocation, error) { func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
if err := loadGeolocationDatabases(dataDir); err != nil { if err := loadGeolocationDatabases(dataDir); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err) return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
} }
@@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
return nil, err return nil, err
} }
locationDB, err := NewSqliteStore(dataDir) locationDB, err := NewSqliteStore(ctx, dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
go geo.reloader() go geo.reloader(ctx)
return geo, nil return geo, nil
} }
@@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error {
return nil return nil
} }
func (gl *Geolocation) reloader() { func (gl *Geolocation) reloader(ctx context.Context) {
for { for {
select { select {
case <-gl.stopCh: case <-gl.stopCh:
return return
case <-time.After(gl.reloadCheckInterval): case <-time.After(gl.reloadCheckInterval):
if err := gl.locationDB.reload(); err != nil { if err := gl.locationDB.reload(ctx); err != nil {
log.Errorf("geonames db reload failed: %s", err) log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
} }
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath) newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue continue
} }
if !bytes.Equal(gl.sha256sum, newSha256sum1) { if !bytes.Equal(gl.sha256sum, newSha256sum1) {
@@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath) newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue continue
} }
if !bytes.Equal(newSha256sum1, newSha256sum2) { if !bytes.Equal(newSha256sum1, newSha256sum2) {
log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath) log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
continue continue
} }
err = gl.reload(newSha256sum2) err = gl.reload(ctx, newSha256sum2)
if err != nil { if err != nil {
log.Errorf("mmdb reload failed: %s", err) log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
} }
} else { } else {
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.", log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
gl.mmdbPath, gl.reloadCheckInterval.Seconds()) gl.mmdbPath, gl.reloadCheckInterval.Seconds())
} }
} }
} }
} }
func (gl *Geolocation) reload(newSha256sum []byte) error { func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
gl.mux.Lock() gl.mux.Lock()
defer gl.mux.Unlock() defer gl.mux.Unlock()
log.Infof("Reloading '%s'", gl.mmdbPath) log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
err := gl.db.Close() err := gl.db.Close()
if err != nil { if err != nil {
@@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error {
gl.db = db gl.db = db
gl.sha256sum = newSha256sum gl.sha256sum = newSha256sum
log.Infof("Successfully reloaded '%s'", gl.mmdbPath) log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
return nil return nil
} }

View File

@@ -2,6 +2,7 @@ package geolocation
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -50,10 +51,10 @@ type SqliteStore struct {
sha256sum []byte sha256sum []byte
} }
func NewSqliteStore(dataDir string) (*SqliteStore, error) { func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
file := filepath.Join(dataDir, GeoSqliteDBFile) file := filepath.Join(dataDir, GeoSqliteDBFile)
db, err := connectDB(file) db, err := connectDB(ctx, file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
} }
// reload attempts to reload the SqliteStore's database if the database file has changed. // reload attempts to reload the SqliteStore's database if the database file has changed.
func (s *SqliteStore) reload() error { func (s *SqliteStore) reload(ctx context.Context) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
newSha256sum1, err := calculateFileSHA256(s.filePath) newSha256sum1, err := calculateFileSHA256(s.filePath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
} }
if !bytes.Equal(s.sha256sum, newSha256sum1) { if !bytes.Equal(s.sha256sum, newSha256sum1) {
@@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error {
return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath) return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
} }
log.Infof("Reloading '%s'", s.filePath) log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
_ = s.close() _ = s.close()
s.closed = true s.closed = true
newDb, err := connectDB(s.filePath) newDb, err := connectDB(ctx, s.filePath)
if err != nil { if err != nil {
return err return err
} }
@@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error {
s.closed = false s.closed = false
s.db = newDb s.db = newDb
log.Infof("Successfully reloaded '%s'", s.filePath) log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
} else { } else {
log.Tracef("No changes in '%s', no need to reload", s.filePath) log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
} }
return nil return nil
@@ -168,10 +169,10 @@ func (s *SqliteStore) close() error {
} }
// connectDB connects to an SQLite database and prepares it by setting up an in-memory database. // connectDB connects to an SQLite database and prepares it by setting up an in-memory database.
func connectDB(filePath string) (*gorm.DB, error) { func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
log.Debugf("took %v to setup geoname db", time.Since(start)) log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start))
}() }()
_, err := fileExists(filePath) _, err := fileExists(filePath)

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"github.com/rs/xid" "github.com/rs/xid"
@@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string {
} }
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -108,11 +109,11 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
} }
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -150,11 +151,12 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
account.Groups[newGroup.ID] = newGroup account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.updateAccountPeers(account) // todo: check if groups is in use by dns, acl, routes and before/after peers
am.updateAccountPeers(ctx, account)
// the following snippet tracks the activity and stores the group events in the event store. // the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed. // It has to happen after all the operations have been successfully performed.
@@ -165,16 +167,16 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
addedPeers = append(addedPeers, newGroup.Peers...) addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
} }
for _, p := range addedPeers { for _, p := range addedPeers {
peer := account.Peers[p] peer := account.Peers[p]
if peer == nil { if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue continue
} }
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer, am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{ map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
@@ -184,10 +186,10 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
for _, p := range removedPeers { for _, p := range removedPeers {
peer := account.Peers[p] peer := account.Peers[p]
if peer == nil { if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue continue
} }
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer, am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{ map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
@@ -213,11 +215,11 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountId) unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountId) account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
@@ -315,23 +317,24 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
delete(account.Groups, groupID) delete(account.Groups, groupID)
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.updateAccountPeers(account) // todo: check if groups is in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
return nil return nil
} }
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -345,11 +348,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
} }
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -371,21 +374,22 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
} }
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.updateAccountPeers(account) // todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
return nil return nil
} }
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -399,13 +403,14 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
for i, itemID := range group.Peers { for i, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil { if err := am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
} }
} }
am.updateAccountPeers(account) // todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
} }
for _, group := range account.Groups { for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedIntegration group.Issued = nbgroup.GroupIssuedIntegration
err = am.SaveGroup(account.Id, groupAdminUserID, group) err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil { if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
} }
@@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups { for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedJWT group.Issued = nbgroup.GroupIssuedJWT
err = am.SaveGroup(account.Id, groupAdminUserID, group) err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil { if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
} }
@@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups { for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedAPI group.Issued = nbgroup.GroupIssuedAPI
group.ID = "" group.ID = ""
err = am.SaveGroup(account.Id, groupAdminUserID, group) err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err == nil { if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name) t.Errorf("should not create api group with the same name, %s", group.Name)
} }
@@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID) err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID)
if err == nil { if err == nil {
t.Errorf("delete %s group successfully", testCase.groupID) t.Errorf("delete %s group successfully", testCase.groupID)
return return
@@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
Id: "example user", Id: "example user",
AutoGroups: []string{groupForUsers.ID}, AutoGroups: []string{groupForUsers.ID},
} }
account := newAccountWithId(accountID, groupAdminUserID, domain) account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
@@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
account.SetupKeys[setupKey.Id] = setupKey account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user account.Users[user.Id] = user
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(account.Id) return am.Store.GetAccount(context.Background(), account.Id)
} }

View File

@@ -11,12 +11,14 @@ import (
pb "github.com/golang/protobuf/proto" // nolint pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/management/server/posture"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -40,7 +42,7 @@ type GRPCServer struct {
} }
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -50,6 +52,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator( jwtValidator, err = jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation, config.HttpConfig.AuthKeysLocation,
@@ -59,7 +62,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
} }
} else { } else {
log.Debug("unable to use http config to create new jwt middleware") log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
} }
if appMetrics != nil { if appMetrics != nil {
@@ -126,47 +129,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest() s.appMetrics.GRPCMetrics().CountSyncRequest()
} }
realIP := getRealIP(srv.Context())
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) ctx := srv.Context()
realIP := getRealIP(ctx)
syncReq := &proto.SyncRequest{} syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(req, syncReq) peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil { if err != nil {
return err return err
} }
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil { if syncReq.GetMeta() == nil {
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
} }
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP) peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil { if err != nil {
return mapError(err) return mapError(ctx, err)
} }
err = s.sendInitialSync(peerKey, peer, netMap, srv) err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil { if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err return err
} }
updates := s.peersUpdateManager.CreateChannel(peer.ID) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(peer) s.ephemeralManager.OnPeerConnected(ctx, peer)
if s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID) s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
} }
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
} }
return s.handleUpdates(peerKey, peer, updates, srv) return s.handleUpdates(ctx, peerKey, peer, updates, srv)
} }
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for { for {
select { select {
// condition when there are some updates // condition when there are some updates
@@ -176,21 +193,21 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
} }
if !open { if !open {
log.Debugf("updates channel for peer %s was closed", peerKey.String()) log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer) s.cancelPeerRoutines(ctx, peer)
return nil return nil
} }
log.Debugf("received an update for peer %s", peerKey.String()) log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil { if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
return err return err
} }
// condition when client <-> server connection has been terminated // condition when client <-> server connection has been terminated
case <-srv.Context().Done(): case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects // happens when connection drops, e.g. client disconnects
log.Debugf("stream of peer %s has been closed", peerKey.String()) log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(peer) s.cancelPeerRoutines(ctx, peer)
return srv.Context().Err() return srv.Context().Err()
} }
} }
@@ -198,10 +215,10 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
// sendUpdate encrypts the update message using the peer key and the server's wireguard key, // sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server. // then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil { if err != nil {
s.cancelPeerRoutines(peer) s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed processing update message") return status.Errorf(codes.Internal, "failed processing update message")
} }
err = srv.SendMsg(&proto.EncryptedMessage{ err = srv.SendMsg(&proto.EncryptedMessage{
@@ -209,37 +226,37 @@ func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *
Body: encryptedResp, Body: encryptedResp,
}) })
if err != nil { if err != nil {
s.cancelPeerRoutines(peer) s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed sending update message") return status.Errorf(codes.Internal, "failed sending update message")
} }
log.Debugf("sent an update to peer %s", peerKey.String()) log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil return nil
} }
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID) s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(peer) _ = s.accountManager.CancelPeerRoutines(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
} }
func (s *GRPCServer) validateToken(jwtToken string) (string, error) { func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.jwtValidator == nil { if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set") return "", status.Error(codes.Internal, "no jwt validator set")
} }
token, err := s.jwtValidator.ValidateAndParse(jwtToken) token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
if err != nil { if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
} }
claims := s.jwtClaimsExtractor.FromToken(token) claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account // we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims) _, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
if err != nil { if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
} }
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil { if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error()) return "", status.Errorf(codes.PermissionDenied, err.Error())
} }
@@ -247,7 +264,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
} }
// maps internal internalStatus.Error to gRPC status.Error // maps internal internalStatus.Error to gRPC status.Error
func mapError(err error) error { func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok { if e, ok := internalStatus.FromError(err); ok {
switch e.Type() { switch e.Type() {
case internalStatus.PermissionDenied: case internalStatus.PermissionDenied:
@@ -263,11 +280,11 @@ func mapError(err error) error {
default: default:
} }
} }
log.Errorf("got an unhandled error: %s", err) log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request") return status.Errorf(codes.Internal, "failed handling request")
} }
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta { func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil { if meta == nil {
return nbpeer.PeerSystemMeta{} return nbpeer.PeerSystemMeta{}
} }
@@ -281,7 +298,7 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
for _, addr := range meta.GetNetworkAddresses() { for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP()) netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil { if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err) log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
continue continue
} }
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{ networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
@@ -321,10 +338,10 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
} }
} }
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
} }
@@ -351,22 +368,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.appMetrics.GRPCMetrics().CountLoginRequest() s.appMetrics.GRPCMetrics().CountLoginRequest()
} }
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
loginReq := &proto.LoginRequest{} loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(req, loginReq) peerKey, err := s.parseRequest(ctx, req, loginReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
if loginReq.GetMeta() == nil { if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition, msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg) log.WithContext(ctx).Warn(msg)
return nil, msg return nil, msg
} }
userID, err := s.processJwtToken(loginReq, peerKey) userID, err := s.processJwtToken(ctx, loginReq, peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -376,33 +403,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
sshKey = loginReq.GetPeerKeys().GetSshPubKey() sshKey = loginReq.GetPeerKeys().GetSshPubKey()
} }
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{ peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
WireGuardPubKey: peerKey.String(), WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey), SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq.GetMeta()), Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
UserID: userID, UserID: userID,
SetupKey: loginReq.GetSetupKey(), SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP, ConnectionIP: realIP,
}) })
if err != nil { if err != nil {
log.Warnf("failed logging in peer %s: %s", peerKey, err) log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(err) return nil, mapError(ctx, err)
} }
// if the login request contains setup key then it is a registration request // if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" { if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
} }
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(s.accountManager, peerKey.String()), Checks: toProtocolChecks(ctx, postureChecks),
} }
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil { if err != nil {
log.Warnf("failed encrypting peer %s message", peer.ID) log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer") return nil, status.Errorf(codes.Internal, "failed logging in peer")
} }
@@ -417,16 +444,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
// //
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already // The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register. // registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := "" userID := ""
if loginReq.GetJwtToken() != "" { if loginReq.GetJwtToken() != "" {
var err error var err error
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
userID, err = s.validateToken(loginReq.GetJwtToken()) userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
if err == nil { if err == nil {
break break
} }
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+ log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err) "Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
} }
@@ -520,7 +547,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers return remotePeers
} }
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials) wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName) pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -551,7 +578,7 @@ func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.
FirewallRules: firewallRules, FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0, FirewallRulesIsEmpty: len(firewallRules) == 0,
}, },
Checks: toProtocolChecks(accountManager, peer.Key), Checks: toProtocolChecks(ctx, checks),
} }
} }
@@ -561,7 +588,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
} }
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional // make secret time based TURN credentials optional
var turnCredentials *TURNCredentials var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig.TimeBasedCredentials {
@@ -570,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else { } else {
turnCredentials = nil turnCredentials = nil
} }
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain()) plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {
@@ -583,7 +610,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
}) })
if err != nil { if err != nil {
log.Errorf("failed sending SyncResponse %v", err) log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
@@ -597,14 +624,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG) log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil { if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG) log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
@@ -645,18 +672,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information // GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow // This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG) log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil { if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG) log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
@@ -692,10 +719,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
// peer's under the same account of any updates. // peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{} syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(req, syncMetaReq) peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -703,27 +730,21 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
if syncMetaReq.GetMeta() == nil { if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition, msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) "peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg) log.WithContext(ctx).Warn(msg)
return nil, msg return nil, msg
} }
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta())) err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
if err != nil { if err != nil {
return nil, mapError(err) return nil, mapError(ctx, err)
} }
return &proto.Empty{}, nil return &proto.Empty{}, nil
} }
// toProtocolChecks returns posture checks for the peer that needs to be evaluated on the client side. // toProtocolChecks converts posture checks to protocol checks.
func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Checks { func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey) protoChecks := make([]*proto.Checks, 0, len(postureChecks))
if err != nil {
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
return nil
}
protoChecks := make([]*proto.Checks, 0)
for _, postureCheck := range postureChecks { for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck)) protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
} }
@@ -732,7 +753,7 @@ func toProtocolChecks(accountManager AccountManager, peerKey string) []*proto.Ch
} }
// toProtocolCheck converts a posture.Checks to a proto.Checks. // toProtocolCheck converts a posture.Checks to a proto.Checks.
func toProtocolCheck(postureCheck posture.Checks) *proto.Checks { func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
protoCheck := &proto.Checks{} protoCheck := &proto.Checks{}
if check := postureCheck.Checks.ProcessCheck; check != nil { if check := postureCheck.Checks.ProcessCheck; check != nil {

View File

@@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
if !(user.HasAdminPower() || user.IsServiceUser) { if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
return return
} }
resp := toAccountResponse(account) resp := toAccountResponse(account)
util.WriteJSONObject(w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(claims) _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
accountID := vars["accountId"] accountID := vars["accountId"]
if len(accountID) == 0 { if len(accountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return return
} }
@@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(updatedAccount) resp := toAccountResponse(updatedAccount)
util.WriteJSONObject(w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteAccount is a HTTP DELETE handler to delete an account // DeleteAccount is a HTTP DELETE handler to delete an account
@@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
vars := mux.Vars(r) vars := mux.Vars(r)
targetAccountID := vars["accountId"] targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 { if len(targetAccountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return return
} }
err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId) err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
func toAccountResponse(account *server.Account) *api.Account { func toAccountResponse(account *server.Account) *api.Account {

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -22,10 +23,10 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{ return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return account, admin, nil return account, admin, nil
}, },
UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) { UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")

View File

@@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account // GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id) dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
DisabledManagementGroups: dnsSettings.DisabledManagementGroups, DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
} }
util.WriteJSONObject(w, apiDNSSettings) util.WriteJSONObject(r.Context(), w, apiDNSSettings)
} }
// UpdateDNSSettings handles update to DNS settings of an account // UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups, DisabledManagementGroups: req.DisabledManagementGroups,
} }
err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings) err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups, DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups,
} }
util.WriteJSONObject(w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{
func initDNSSettingsTestData() *DNSSettingsHandler { func initDNSSettingsTestData() *DNSSettingsHandler {
return &DNSSettingsHandler{ return &DNSSettingsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) { GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil return &testingDNSSettingsAccount.DNSSettings, nil
}, },
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
if dnsSettingsToSave != nil { if dnsSettingsToSave != nil {
return nil return nil
} }
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
}, },
}, },

View File

@@ -1,6 +1,7 @@
package http package http
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
@@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account // GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id) accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
events := make([]*api.Event, len(accountEvents)) events := make([]*api.Event, len(accountEvents))
@@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e) events[i] = toEventResponse(e)
} }
err = h.fillEventsWithUserInfo(events, account.Id, user.Id) err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, events) util.WriteJSONObject(r.Context(), w, events)
} }
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error { func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
// build email, name maps based on users // build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
if err != nil { if err != nil {
log.Errorf("failed to get users from account: %s", err) log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
return err return err
} }
@@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u
if event.InitiatorEmail == "" { if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId] event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok { if !ok {
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
} }
} }

View File

@@ -1,6 +1,7 @@
package http package http
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -22,13 +23,13 @@ import (
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
return &EventsHandler{ return &EventsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) { GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
if accountID == account { if accountID == account {
return events, nil return events, nil
} }
return []*activity.Event{}, nil return []*activity.Event{}, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
@@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}, },
}, user, nil }, user, nil
}, },
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil return make([]*server.UserInfo, 0), nil
}, },
}, },

View File

@@ -1,6 +1,7 @@
package http package http
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile)) err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
assert.NoError(t, err) assert.NoError(t, err)
geo, err := geolocation.NewGeolocation(tempDir) geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
assert.NoError(t, err) assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() }) t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{ return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,

View File

@@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
// GetAllCountries retrieves a list of all countries // GetAllCountries retrieves a list of all countries
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil { if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
if l.geolocationManager == nil { if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready // TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return return
} }
allCountries, err := l.geolocationManager.GetAllCountries() allCountries, err := l.geolocationManager.GetAllCountries()
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
for _, country := range allCountries { for _, country := range allCountries {
countries = append(countries, toCountryResponse(country)) countries = append(countries, toCountryResponse(country))
} }
util.WriteJSONObject(w, countries) util.WriteJSONObject(r.Context(), w, countries)
} }
// GetCitiesByCountry retrieves a list of cities based on the given country code // GetCitiesByCountry retrieves a list of cities based on the given country code
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil { if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
countryCode := vars["country"] countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) { if !countryCodeRegex.MatchString(countryCode) {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w)
return return
} }
if l.geolocationManager == nil { if l.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w) "Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return return
} }
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode) allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
for _, city := range allCities { for _, city := range allCities {
cities = append(cities, toCityResponse(city)) cities = append(cities, toCityResponse(city))
} }
util.WriteJSONObject(w, cities) util.WriteJSONObject(r.Context(), w, cities)
} }
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r) claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(claims) _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account // GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
groups, err := h.accountManager.GetAllGroups(account.Id, user.Id) groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse = append(groupsResponse, toGroupResponse(account, group)) groupsResponse = append(groupsResponse, toGroupResponse(account, group))
} }
util.WriteJSONObject(w, groupsResponse) util.WriteJSONObject(r.Context(), w, groupsResponse)
} }
// UpdateGroup handles update to a group identified by a given ID // UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
groupID, ok := vars["groupId"] groupID, ok := vars["groupId"]
if !ok { if !ok {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return return
} }
if len(groupID) == 0 { if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return return
} }
eg, ok := account.Groups[groupID] eg, ok := account.Groups[groupID]
if !ok { if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
} }
@@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
} }
if req.Name == "" { if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: eg.IntegrationReference, IntegrationReference: eg.IntegrationReference,
} }
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil { if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toGroupResponse(account, &group)) util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
} }
// CreateGroup handles group creation request // CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
} }
if req.Name == "" { if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
} }
err = h.accountManager.SaveGroup(account.Id, user.Id, &group) err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toGroupResponse(account, &group)) util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
} }
// DeleteGroup handles group deletion request // DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
aID := account.Id aID := account.Id
groupID := mux.Vars(r)["groupId"] groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 { if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return return
} }
err = h.accountManager.DeleteGroup(aID, user.Id, groupID) err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
if err != nil { if err != nil {
_, ok := err.(*server.GroupLinkError) _, ok := err.(*server.GroupLinkError)
if ok { if ok {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return return
} }
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
// GetGroup returns a group // GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
case http.MethodGet: case http.MethodGet:
groupID := mux.Vars(r)["groupId"] groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 { if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id) group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toGroupResponse(account, group)) util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default: default:
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
return return
} }
} }

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{ return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
if !strings.HasPrefix(group.ID, "id-") { if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set" group.ID = "id-was-set"
} }
return nil return nil
}, },
GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" { if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found") return nil, status.Errorf(status.NotFound, "not found")
} }
@@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
Issued: nbgroup.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
}, nil }, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
@@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
}, },
}, user, nil }, user, nil
}, },
DeleteGroupFunc: func(accountID, userId, groupID string) error { DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" { if groupID == "linked-grp" {
return &server.GroupLinkError{ return &server.GroupLinkError{
Resource: "something", Resource: "something",

View File

@@ -9,6 +9,7 @@ import (
"github.com/rs/cors" "github.com/rs/cors"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
@@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
claimsExtractor = jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
acMiddleware := middleware.NewAccessControl( acMiddleware := middleware.NewAccessControl(
authCfg.Audience, authCfg.Audience,
authCfg.UserIDClaim, authCfg.UserIDClaim,

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"net/http" "net/http"
"regexp" "regexp"
@@ -15,7 +16,7 @@ import (
) )
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error) type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct { type AccessControl struct {
@@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
claims := a.claimsExtract.FromRequestContext(r) claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(claims) user, err := a.getUser(r.Context(), claims)
if err != nil { if err != nil {
log.Errorf("failed to get user from claims: %s", err) log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return
} }
if user.IsBlocked() { if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w) util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return return
} }
@@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
if tokenPathRegexp.MatchString(r.URL.Path) { if tokenPathRegexp.MatchString(r.URL.Path) {
log.Debugf("valid Path") log.WithContext(r.Context()).Debugf("valid Path")
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
return return
} }
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w) util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
return return
} }
} }

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -19,16 +20,16 @@ import (
) )
// GetAccountFromPATFunc function // GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function // ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
// MarkPATUsedFunc function // MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error type MarkPATUsedFunc func(ctx context.Context, token string) error
// CheckUserAccessByJWTGroupsFunc function // CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
@@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
case "bearer": case "bearer":
err := m.checkJWTFromRequest(w, r, auth) err := m.checkJWTFromRequest(w, r, auth)
if err != nil { if err != nil {
log.Errorf("Error when validating JWT claims: %s", err.Error()) log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return return
} }
h.ServeHTTP(w, r)
case "token": case "token":
err := m.checkPATFromRequest(w, r, auth) err := m.checkPATFromRequest(w, r, auth)
if err != nil { if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error()) log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return return
} }
h.ServeHTTP(w, r)
default: default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return return
} }
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
@@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("Error extracting token: %w", err)
} }
validatedToken, err := m.validateAndParseToken(token) validatedToken, err := m.validateAndParseToken(r.Context(), token)
if err != nil { if err != nil {
return err return err
} }
@@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil return nil
} }
if err := m.verifyUserAccess(validatedToken); err != nil { if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
return err return err
} }
@@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// verifyUserAccess checks if a user, based on a validated JWT token, // verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT // is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions. // group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error { func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken) authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(authClaims) return m.checkUserAccessByJWTGroups(ctx, authClaims)
} }
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
@@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("Error extracting token: %w", err)
} }
account, user, pat, err := m.getAccountFromPAT(token) account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
if err != nil { if err != nil {
return fmt.Errorf("invalid Token: %w", err) return fmt.Errorf("invalid Token: %w", err)
} }
@@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("token expired") return fmt.Errorf("token expired")
} }
err = m.markPATUsed(pat.ID) err = m.markPATUsed(r.Context(), pat.ID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -15,15 +16,16 @@ import (
) )
const ( const (
audience = "audience" audience = "audience"
userIDClaim = "userIDClaim" userIDClaim = "userIDClaim"
accountID = "accountID" accountID = "accountID"
domain = "domain" domain = "domain"
userID = "userID" domainCategory = "domainCategory"
tokenID = "tokenID" userID = "userID"
PAT = "nbp_PAT" tokenID = "tokenID"
JWT = "JWT" PAT = "nbp_PAT"
wrongToken = "wrongToken" JWT = "JWT"
wrongToken = "wrongToken"
) )
var testAccount = &server.Account{ var testAccount = &server.Account{
@@ -47,14 +49,14 @@ var testAccount = &server.Account{
}, },
} }
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT { if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
} }
return nil, nil, nil, fmt.Errorf("PAT invalid") return nil, nil, nil, fmt.Errorf("PAT invalid")
} }
func mockValidateAndParseToken(token string) (*jwt.Token, error) { func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
if token == JWT { if token == JWT {
return &jwt.Token{ return &jwt.Token{
Claims: jwt.MapClaims{ Claims: jwt.MapClaims{
@@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
return nil, fmt.Errorf("JWT invalid") return nil, fmt.Errorf("JWT invalid")
} }
func mockMarkPATUsed(token string) error { func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID { if token == tokenID {
return nil return nil
} }
return fmt.Errorf("Should never get reached") return fmt.Errorf("Should never get reached")
} }
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId { if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId) return fmt.Errorf("account with id %s does not exist", claims.AccountId)
} }

View File

@@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
for bypassPath := range bypassPaths { for bypassPath := range bypassPaths {
matched, err := path.Match(bypassPath, requestPath) matched, err := path.Match(bypassPath, requestPath)
if err != nil { if err != nil {
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue continue
} }
if matched { if matched {

View File

@@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account // GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id) nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
} }
util.WriteJSONObject(w, apiNameservers) util.WriteJSONObject(r.Context(), w, apiNameservers)
} }
// CreateNameserverGroup handles nameserver group creation request // CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers) nsList, err := toServerNSList(req.Nameservers)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID // UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
nsGroupID := mux.Vars(r)["nsgroupId"] nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
@@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers) nsList, err := toServerNSList(req.Nameservers)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return return
} }
@@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled, SearchDomainsEnabled: req.SearchDomainsEnabled,
} }
err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup) err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toNameserverGroupResponse(updatedNSGroup) resp := toNameserverGroupResponse(updatedNSGroup)
util.WriteJSONObject(w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteNameserverGroup handles nameserver group deletion request // DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
nsGroupID := mux.Vars(r)["nsgroupId"] nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id) err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
// GetNameserverGroup handles a nameserver group Get request identified by ID // GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["nsgroupId"] nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID) nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
func initNameserversTestData() *NameserversHandler { func initNameserversTestData() *NameserversHandler {
return &NameserversHandler{ return &NameserversHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID { if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil return baseExistingNSGroup.Copy(), nil
} }
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
}, },
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{ return &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: name, Name: name,
@@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler {
SearchDomainsEnabled: searchDomains, SearchDomainsEnabled: searchDomains,
}, nil }, nil
}, },
DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error { DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error {
return nil return nil
}, },
SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error { SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
if nsGroupToSave.ID == existingNSGroupID { if nsGroupToSave.ID == existingNSGroupID {
return nil return nil
} }
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil return testingNSAccount, testingAccount.Users["test_user"], nil
}, },
}, },

View File

@@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["userId"] userID := vars["userId"]
if len(userID) == 0 { if len(userID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID) pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
patResponse = append(patResponse, toPATResponse(pat)) patResponse = append(patResponse, toPATResponse(pat))
} }
util.WriteJSONObject(w, patResponse) util.WriteJSONObject(r.Context(), w, patResponse)
} }
// GetToken is HTTP GET handler that returns a personal access token for the given user // GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
tokenID := vars["tokenId"] tokenID := vars["tokenId"]
if len(tokenID) == 0 { if len(tokenID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return return
} }
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID) pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toPATResponse(pat)) util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
} }
// CreateToken is HTTP POST handler that creates a personal access token for the given user // CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
@@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return return
} }
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toPATGeneratedResponse(pat)) util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
} }
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
tokenID := vars["tokenId"] tokenID := vars["tokenId"]
if len(tokenID) == 0 { if len(tokenID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return return
} }
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID) err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -63,7 +64,7 @@ var testAccount = &server.Account{
func initPATTestData() *PATHandler { func initPATTestData() *PATHandler {
return &PATHandler{ return &PATHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }
@@ -76,10 +77,10 @@ func initPATTestData() *PATHandler {
}, nil }, nil
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil return testAccount, testAccount.Users[existingUserID], nil
}, },
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID { if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID) return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }
@@ -91,7 +92,7 @@ func initPATTestData() *PATHandler {
} }
return nil return nil
}, },
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }
@@ -103,7 +104,7 @@ func initPATTestData() *PATHandler {
} }
return testAccount.Users[existingUserID].PATs[existingTokenID], nil return testAccount.Users[existingUserID].PATs[existingTokenID], nil
}, },
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
} }

View File

@@ -1,6 +1,7 @@
package http package http
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -47,16 +48,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil return peerToReturn, nil
} }
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) { func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID) peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(ctx, err, w)
return return
} }
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
@@ -65,19 +66,19 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
validPeers, err := h.accountManager.GetValidatedPeers(account) validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
} }
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@@ -99,9 +100,9 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
} }
} }
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
@@ -110,75 +111,75 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
validPeers, err := h.accountManager.GetValidatedPeers(account) validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
} }
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(accountID, peerID, userID) err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
if err != nil { if err != nil {
log.Errorf("failed to delete peer: %v", err) log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(err, w) util.WriteError(ctx, err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(ctx, w, emptyObject{})
} }
// HandlePeer handles all peer requests for GET, PUT and DELETE operations // HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
peerID := vars["peerId"] peerID := vars["peerId"]
if len(peerID) == 0 { if len(peerID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return return
} }
switch r.Method { switch r.Method {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(account.Id, user.Id, peerID, w) h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
return return
case http.MethodPut: case http.MethodPut:
h.updatePeer(account, user, peerID, w, r) h.updatePeer(r.Context(), account, user, peerID, w, r)
return return
case http.MethodGet: case http.MethodGet:
h.getPeer(account, peerID, user.Id, w) h.getPeer(r.Context(), account, peerID, user.Id, w)
return return
default: default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
} }
// GetAllPeers returns a list of all peers associated with a provided account // GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
return return
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
peers, err := h.accountManager.GetPeers(account.Id, user.Id) peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -188,34 +189,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
for _, peer := range peers { for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(account) validPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.Errorf("failed to list appreoved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
h.setApprovalRequiredFlag(respBody, validPeersMap) h.setApprovalRequiredFlag(respBody, validPeersMap)
util.WriteJSONObject(w, respBody) util.WriteJSONObject(r.Context(), w, respBody)
} }
func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
return 0, err return 0, err
} }
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil return len(netMap.Peers) + len(netMap.OfflinePeers), nil
} }

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net" "net"
@@ -29,7 +30,7 @@ const noUpdateChannelTestPeerID = "no-update-channel"
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{ return &PeersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
var p *nbpeer.Peer var p *nbpeer.Peer
for _, peer := range peers { for _, peer := range peers {
if update.ID == peer.ID { if update.ID == peer.ID {
@@ -42,7 +43,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
p.Name = update.Name p.Name = update.Name
return p, nil return p, nil
}, },
GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) { GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer var p *nbpeer.Peer
for _, peer := range peers { for _, peer := range peers {
if peerID == peer.ID { if peerID == peer.ID {
@@ -52,13 +53,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
} }
return p, nil return p, nil
}, },
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil return peers, nil
}, },
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func() string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,

View File

@@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account // GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id) accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
for _, policy := range accountPolicies { for _, policy := range accountPolicies {
resp := toPolicyResponse(account, policy) resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
} }
policies = append(policies, resp) policies = append(policies, resp)
} }
util.WriteJSONObject(w, policies) util.WriteJSONObject(r.Context(), w, policies)
} }
// UpdatePolicy handles update to a policy identified by a given ID // UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
if len(policyID) == 0 { if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return return
} }
@@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
} }
} }
if policyIdx < 0 { if policyIdx < 0 {
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return return
} }
@@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
// CreatePolicy handles policy creation request // CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -118,12 +118,12 @@ func (h *Policies) savePolicy(
} }
if req.Name == "" { if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
return return
} }
if len(req.Rules) == 0 { if len(req.Rules) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
return return
} }
@@ -137,31 +137,31 @@ func (h *Policies) savePolicy(
Enabled: req.Enabled, Enabled: req.Enabled,
Description: req.Description, Description: req.Description,
} }
for _, r := range req.Rules { for _, rule := range req.Rules {
pr := server.PolicyRule{ pr := server.PolicyRule{
ID: policyID, //TODO: when policy can contain multiple rules, need refactor ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: r.Name, Name: rule.Name,
Destinations: groupMinimumsToStrings(account, r.Destinations), Destinations: groupMinimumsToStrings(account, rule.Destinations),
Sources: groupMinimumsToStrings(account, r.Sources), Sources: groupMinimumsToStrings(account, rule.Sources),
Bidirectional: r.Bidirectional, Bidirectional: rule.Bidirectional,
} }
pr.Enabled = r.Enabled pr.Enabled = rule.Enabled
if r.Description != nil { if rule.Description != nil {
pr.Description = *r.Description pr.Description = *rule.Description
} }
switch r.Action { switch rule.Action {
case api.PolicyRuleUpdateActionAccept: case api.PolicyRuleUpdateActionAccept:
pr.Action = server.PolicyTrafficActionAccept pr.Action = server.PolicyTrafficActionAccept
case api.PolicyRuleUpdateActionDrop: case api.PolicyRuleUpdateActionDrop:
pr.Action = server.PolicyTrafficActionDrop pr.Action = server.PolicyTrafficActionDrop
default: default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
return return
} }
switch r.Protocol { switch rule.Protocol {
case api.PolicyRuleUpdateProtocolAll: case api.PolicyRuleUpdateProtocolAll:
pr.Protocol = server.PolicyRuleProtocolALL pr.Protocol = server.PolicyRuleProtocolALL
case api.PolicyRuleUpdateProtocolTcp: case api.PolicyRuleUpdateProtocolTcp:
@@ -171,14 +171,14 @@ func (h *Policies) savePolicy(
case api.PolicyRuleUpdateProtocolIcmp: case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = server.PolicyRuleProtocolICMP pr.Protocol = server.PolicyRuleProtocolICMP
default: default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return return
} }
if r.Ports != nil && len(*r.Ports) != 0 { if rule.Ports != nil && len(*rule.Ports) != 0 {
for _, v := range *r.Ports { for _, v := range *rule.Ports {
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
return return
} }
pr.Ports = append(pr.Ports, v) pr.Ports = append(pr.Ports, v)
@@ -189,16 +189,16 @@ func (h *Policies) savePolicy(
switch pr.Protocol { switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
if len(pr.Ports) != 0 { if len(pr.Ports) != 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return return
} }
if !pr.Bidirectional { if !pr.Bidirectional {
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return return
} }
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
if !pr.Bidirectional && len(pr.Ports) == 0 { if !pr.Bidirectional && len(pr.Ports) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return return
} }
} }
@@ -210,26 +210,26 @@ func (h *Policies) savePolicy(
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
} }
if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil { if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toPolicyResponse(account, &policy) resp := toPolicyResponse(account, &policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
} }
util.WriteJSONObject(w, resp) util.WriteJSONObject(r.Context(), w, resp)
} }
// DeletePolicy handles policy deletion request // DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
aID := account.Id aID := account.Id
@@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
if len(policyID) == 0 { if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return return
} }
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil { if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
// GetPolicy handles a group Get request identified by ID // GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
if len(policyID) == 0 { if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return return
} }
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id) policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toPolicyResponse(account, policy) resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
} }
util.WriteJSONObject(w, resp) util.WriteJSONObject(r.Context(), w, resp)
default: default:
util.WriteError(status.Errorf(status.NotFound, "method not found"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
} }
} }

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -30,21 +31,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
} }
return &Policies{ return &Policies{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) { GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
policy, ok := testPolicies[policyID] policy, ok := testPolicies[policyID]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "policy not found") return nil, status.Errorf(status.NotFound, "policy not found")
} }
return policy, nil return policy, nil
}, },
SavePolicyFunc: func(_, _ string, policy *server.Policy) error { SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
if !strings.HasPrefix(policy.ID, "id-") { if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set" policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set" policy.Rules[0].ID = "id-was-set"
} }
return nil return nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,

View File

@@ -37,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account // GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims) account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id) accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -54,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
} }
util.WriteJSONObject(w, postureChecks) util.WriteJSONObject(r.Context(), w, postureChecks)
} }
// UpdatePostureCheck handles update to a posture check identified by a given ID // UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims) account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"] postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 { if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return return
} }
@@ -81,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
} }
} }
if postureChecksIdx < 0 { if postureChecksIdx < 0 {
util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
return return
} }
@@ -91,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
// CreatePostureCheck handles posture check creation request // CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims) account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -103,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
// GetPostureCheck handles a posture check Get request identified by ID // GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims) account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"] postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 { if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return return
} }
postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id) postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, postureChecks.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
} }
// DeletePostureCheck handles posture check deletion request // DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims) account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"] postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 { if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return return
} }
if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil { if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
// savePostureChecks handles posture checks create and update // savePostureChecks handles posture checks create and update
@@ -169,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil { if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil { if p.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w) "Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return return
} }
@@ -177,14 +177,14 @@ func (p *PostureChecksHandler) savePostureChecks(
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID) postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil { if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, postureChecks.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
} }

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -33,14 +34,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return &PostureChecksHandler{ return &PostureChecksHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) { GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
p, ok := testPostureChecks[postureChecksID] p, ok := testPostureChecks[postureChecksID]
if !ok { if !ok {
return nil, status.Errorf(status.NotFound, "posture checks not found") return nil, status.Errorf(status.NotFound, "posture checks not found")
} }
return p, nil return p, nil
}, },
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error { SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck" postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks testPostureChecks[postureChecks.ID] = postureChecks
@@ -50,7 +51,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return nil return nil
}, },
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error { DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID] _, ok := testPostureChecks[postureChecksID]
if !ok { if !ok {
return status.Errorf(status.NotFound, "posture checks not found") return status.Errorf(status.NotFound, "posture checks not found")
@@ -59,14 +60,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return nil return nil
}, },
ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) { ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) {
accountPostureChecks := make([]*posture.Checks, len(testPostureChecks)) accountPostureChecks := make([]*posture.Checks, len(testPostureChecks))
for _, p := range testPostureChecks { for _, p := range testPostureChecks {
accountPostureChecks = append(accountPostureChecks, p) accountPostureChecks = append(accountPostureChecks, p)
} }
return accountPostureChecks, nil return accountPostureChecks, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,

View File

@@ -43,36 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account // GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
routes, err := h.accountManager.ListRoutes(account.Id, user.Id) routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
apiRoutes := make([]*api.Route, 0) apiRoutes := make([]*api.Route, 0)
for _, r := range routes { for _, route := range routes {
route, err := toRouteResponse(r) route, err := toRouteResponse(route)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return return
} }
apiRoutes = append(apiRoutes, route) apiRoutes = append(apiRoutes, route)
} }
util.WriteJSONObject(w, apiRoutes) util.WriteJSONObject(r.Context(), w, apiRoutes)
} }
// CreateRoute handles route creation request // CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -84,7 +84,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
} }
if err := h.validateRoute(req); err != nil { if err := h.validateRoute(req); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -94,7 +94,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
if req.Domains != nil { if req.Domains != nil {
d, err := validateDomains(*req.Domains) d, err := validateDomains(*req.Domains)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return return
} }
domains = d domains = d
@@ -102,7 +102,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
} else if req.Network != nil { } else if req.Network != nil {
networkType, newPrefix, err = route.ParseNetwork(*req.Network) networkType, newPrefix, err = route.ParseNetwork(*req.Network)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
} }
@@ -120,24 +120,24 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
// Do not allow non-Linux peers // Do not allow non-Linux peers
if peer := account.GetPeer(peerId); peer != nil { if peer := account.GetPeer(peerId); peer != nil {
if peer.Meta.GoOS != "linux" { if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return return
} }
} }
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
routes, err := toRouteResponse(newRoute) routes, err := toRouteResponse(newRoute)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return return
} }
util.WriteJSONObject(w, routes) util.WriteJSONObject(r.Context(), w, routes)
} }
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
@@ -168,22 +168,22 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID // UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
routeID := vars["routeId"] routeID := vars["routeId"]
if len(routeID) == 0 { if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -195,7 +195,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
} }
if err := h.validateRoute(req); err != nil { if err := h.validateRoute(req); err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -207,7 +207,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// do not allow non Linux peers // do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil { if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" { if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return return
} }
} }
@@ -226,7 +226,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
if req.Domains != nil { if req.Domains != nil {
d, err := validateDomains(*req.Domains) d, err := validateDomains(*req.Domains)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return return
} }
newRoute.Domains = d newRoute.Domains = d
@@ -234,7 +234,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
} else if req.Network != nil { } else if req.Network != nil {
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network) newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
} }
@@ -247,73 +247,73 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups newRoute.PeerGroups = *req.PeerGroups
} }
err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute) err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
routes, err := toRouteResponse(newRoute) routes, err := toRouteResponse(newRoute)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return return
} }
util.WriteJSONObject(w, routes) util.WriteJSONObject(r.Context(), w, routes)
} }
// DeleteRoute handles route deletion request // DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
routeID := mux.Vars(r)["routeId"] routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 { if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id) err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
// GetRoute handles a route Get request identified by ID // GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
routeID := mux.Vars(r)["routeId"] routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 { if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.NotFound, "route not found"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return return
} }
routes, err := toRouteResponse(foundRoute) routes, err := toRouteResponse(foundRoute)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return return
} }
util.WriteJSONObject(w, routes) util.WriteJSONObject(r.Context(), w, routes)
} }
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -89,7 +90,7 @@ var testingAccount = &server.Account{
func initRoutesTestData() *RoutesHandler { func initRoutesTestData() *RoutesHandler {
return &RoutesHandler{ return &RoutesHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) { GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }
@@ -104,7 +105,7 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}, },
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID { if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@@ -126,19 +127,19 @@ func initRoutesTestData() *RoutesHandler {
KeepRoute: keepRoute, KeepRoute: keepRoute,
}, nil }, nil
}, },
SaveRouteFunc: func(_, _ string, r *route.Route) error { SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
if r.Peer == notFoundPeerID { if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
} }
return nil return nil
}, },
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error { DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
if routeID != existingRouteID { if routeID != existingRouteID {
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID) return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
} }
return nil return nil
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil return testingAccount, testingAccount.Users["test_user"], nil
}, },
}, },

View File

@@ -1,6 +1,7 @@
package http package http
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"time" "time"
@@ -34,9 +35,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey // CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -48,13 +49,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
} }
if req.Name == "" { if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
return return
} }
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
return return
} }
@@ -63,7 +64,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
day := time.Hour * 24 day := time.Hour * 24
year := day * 365 year := day * 365
if expiresIn < day || expiresIn > year { if expiresIn < day || expiresIn > year {
util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
return return
} }
@@ -75,54 +76,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil { if req.Ephemeral != nil {
ephemeral = *req.Ephemeral ephemeral = *req.Ephemeral
} }
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, user.Id, ephemeral) req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
writeSuccess(w, setupKey) writeSuccess(r.Context(), w, setupKey)
} }
// GetSetupKey is a GET request to get a SetupKey by ID // GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return return
} }
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID) key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
writeSuccess(w, key) writeSuccess(r.Context(), w, key)
} }
// UpdateSetupKey is a PUT request to update server.SetupKey // UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["keyId"] keyID := vars["keyId"]
if len(keyID) == 0 { if len(keyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return return
} }
@@ -134,12 +135,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
} }
if req.Name == "" { if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return return
} }
if req.AutoGroups == nil { if req.AutoGroups == nil {
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return return
} }
@@ -149,26 +150,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name newKey.Name = req.Name
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id) newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
writeSuccess(w, newKey) writeSuccess(r.Context(), w, newKey)
} }
// GetAllSetupKeys is a GET request that returns a list of SetupKey // GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id) setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -177,15 +178,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
} }
util.WriteJSONObject(w, apiSetupKeys) util.WriteJSONObject(r.Context(), w, apiSetupKeys)
} }
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200) w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key)) err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(ctx, err, w)
return return
} }
} }

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -33,7 +34,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler { ) *SetupKeysHandler {
return &SetupKeysHandler{ return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: testAccountID, Id: testAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
@@ -49,7 +50,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}, },
}, user, nil }, user, nil
}, },
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool, _ int, _ string, ephemeral bool,
) (*server.SetupKey, error) { ) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type { if keyName == newKey.Name || typ != newKey.Type {
@@ -59,7 +60,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
} }
return nil, fmt.Errorf("failed creating setup key") return nil, fmt.Errorf("failed creating setup key")
}, },
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) { GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID { switch keyID {
case defaultKey.Id: case defaultKey.Id:
return defaultKey, nil return defaultKey, nil
@@ -70,14 +71,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
} }
}, },
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
if key.Id == updatedSetupKey.Id { if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil return updatedSetupKey, nil
} }
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
}, },
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) { ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil return []*server.SetupKey{defaultKey}, nil
}, },
}, },

View File

@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["userId"] userID := vars["userId"]
if len(userID) == 0 { if len(userID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
existingUser, ok := account.Users[userID] existingUser, ok := account.Users[userID]
if !ok { if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
return return
} }
@@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
userRole := server.StrRoleToUserRole(req.Role) userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown { if userRole == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
return return
} }
newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{ newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
Id: userID, Id: userID,
Role: userRole, Role: userRole,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
@@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}) })
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
} }
// DeleteUser is a DELETE request to delete a user // DeleteUser is a DELETE request to delete a user
@@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID) err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). // CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
@@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
} }
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return return
} }
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name name = *req.Name
} }
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{ newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
Email: email, Email: email,
Name: name, Name: name,
Role: req.Role, Role: req.Role,
@@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
Issued: server.UserIssuedAPI, Issued: server.UserIssuedAPI,
}) })
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
} }
// GetAllUsers returns a list of users of the account this user belongs to. // GetAllUsers returns a list of users of the account this user belongs to.
@@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id) data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
serviceUser := r.URL.Query().Get("service_user") serviceUser := r.URL.Query().Get("service_user")
users := make([]*api.User, 0) users := make([]*api.User, 0)
for _, r := range data { for _, d := range data {
if r.NonDeletable { if d.NonDeletable {
continue continue
} }
if serviceUser == "" { if serviceUser == "" {
users = append(users, toUserResponse(r, claims.UserId)) users = append(users, toUserResponse(d, claims.UserId))
continue continue
} }
includeServiceUser, err := strconv.ParseBool(serviceUser) includeServiceUser, err := strconv.ParseBool(serviceUser)
log.Debugf("Should include service user: %v", includeServiceUser) log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return return
} }
if includeServiceUser == r.IsServiceUser { if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(r, claims.UserId)) users = append(users, toUserResponse(d, claims.UserId))
} }
} }
util.WriteJSONObject(w, users) util.WriteJSONObject(r.Context(), w, users)
} }
// InviteUser resend invitations to users who haven't activated their accounts, // InviteUser resend invitations to users who haven't activated their accounts,
@@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(targetUserID) == 0 { if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID) err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {

View File

@@ -2,6 +2,7 @@ package http
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -63,10 +64,10 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler { func initUsersTestData() *UsersHandler {
return &UsersHandler{ return &UsersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
}, },
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)
for _, v := range usersTestAccount.Users { for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{ users = append(users, &server.UserInfo{
@@ -81,13 +82,13 @@ func initUsersTestData() *UsersHandler {
} }
return users, nil return users, nil
}, },
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
if userID != existingUserID { if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
} }
return key, nil return key, nil
}, },
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID { if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID) return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
} }
@@ -96,7 +97,7 @@ func initUsersTestData() *UsersHandler {
} }
return nil return nil
}, },
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) { SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID { if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
} }
@@ -111,7 +112,7 @@ func initUsersTestData() *UsersHandler {
} }
return info, nil return info, nil
}, },
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if initiatorUserID != existingUserID { if initiatorUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID) return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
} }

View File

@@ -1,6 +1,7 @@
package util package util
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -19,12 +20,12 @@ type ErrorResponse struct {
} }
// WriteJSONObject simply writes object to the HTTP response in JSON format // WriteJSONObject simply writes object to the HTTP response in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) { func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(obj) err := json.NewEncoder(w).Encode(obj)
if err != nil { if err != nil {
WriteError(err, w) WriteError(ctx, err, w)
return return
} }
} }
@@ -76,8 +77,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
// WriteError converts an error to an JSON error response. // WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise // If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(err error, w http.ResponseWriter) { func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
log.Errorf("got a handler error: %s", err.Error()) log.WithContext(ctx).Errorf("got a handler error: %s", err.Error())
errStatus, ok := status.FromError(err) errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError httpStatus := http.StatusInternalServerError
msg := "internal server error" msg := "internal server error"
@@ -106,7 +107,7 @@ func WriteError(err error, w http.ResponseWriter) {
msg = strings.ToLower(err.Error()) msg = strings.ToLower(err.Error())
} else { } else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG) log.WithContext(ctx).Error(unhandledMSG)
} }
WriteErrorResponse(msg, httpStatus, w) WriteErrorResponse(msg, httpStatus, w)

View File

@@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool {
} }
// requestJWTToken performs request to get jwt token // requestJWTToken performs request to get jwt token
func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
var res *http.Response var res *http.Response
reqURL := c.clientConfig.AuthIssuer + "/oauth/token" reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
@@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
log.Debug("requesting new jwt token for idp manager") log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
res, err = c.httpClient.Do(req) res, err = c.httpClient.Do(req)
if err != nil { if err != nil {
@@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
} }
// Authenticate retrieves access token to use the Auth0 Management API // Authenticate retrieves access token to use the Auth0 Management API
func (c *Auth0Credentials) Authenticate() (JWTToken, error) { func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
@@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
return c.jwtToken, nil return c.jwtToken, nil
} }
res, err := c.requestJWTToken() res, err := c.requestJWTToken(ctx)
if err != nil { if err != nil {
return c.jwtToken, err return c.jwtToken, err
} }
defer func() { defer func() {
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing get jwt token response body: %v", err) log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
} }
}() }()
@@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string {
} }
// GetAccount returns all the users for a given profile. Calls Auth0 API. // GetAccount returns all the users for a given profile. Calls Auth0 API.
func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err return nil, err
} }
log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch)) log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
@@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
} }
if len(batch) == 0 || len(batch) < resultsPerPage { if len(batch) == 0 || len(batch) < resultsPerPage {
log.Debugf("finished loading users for accountID %s", accountID) log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
return list, nil return list, nil
} }
} }
@@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
} }
// GetUserDataByID requests user data from auth0 via ID // GetUserDataByID requests user data from auth0 via ID
func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
defer func() { defer func() {
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err) log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
} }
}() }()
@@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
} }
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map // UpdateUserAppMetadata updates user app metadata based on userId and metadata map
func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
log.Debugf("updating IdP metadata for user %s", userID) log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
@@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
defer func() { defer func() {
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err) log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
} }
}() }()
@@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) {
} }
func (am *Auth0Manager) createRequest( func (am *Auth0Manager) createRequest(
method string, endpoint string, body io.Reader, ctx context.Context, method string, endpoint string, body io.Reader,
) (*http.Request, error) { ) (*http.Request, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest(
return req, nil return req, nil
} }
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr)) req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
payloadString, err := buildUserExportRequest() payloadString, err := buildUserExportRequest()
if err != nil { if err != nil {
return nil, err return nil, err
} }
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString) exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
if err != nil { if err != nil {
return nil, err return nil, err
} }
jobResp, err := am.httpClient.Do(exportJobReq) jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil { if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError() am.appMetrics.IDPMetrics().CountRequestError()
} }
@@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
defer func() { defer func() {
err = jobResp.Body.Close() err = jobResp.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err) log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
} }
}() }()
if jobResp.StatusCode != 200 { if jobResp.StatusCode != 200 {
@@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
body, err := io.ReadAll(jobResp.Body) body, err := io.ReadAll(jobResp.Body)
if err != nil { if err != nil {
log.Debugf("Couldn't read export job response; %v", err) log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
return nil, err return nil, err
} }
err = am.helper.Unmarshal(body, &exportJobResp) err = am.helper.Unmarshal(body, &exportJobResp)
if err != nil { if err != nil {
log.Debugf("Couldn't unmarshal export job response; %v", err) log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err return nil, err
} }
@@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
} }
log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID) done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
if err != nil { if err != nil {
log.Debugf("Failed at getting status checks from exportJob; %v", err) log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
return nil, err return nil, err
} }
if done { if done {
return am.downloadProfileExport(downloadLink) return am.downloadProfileExport(ctx, downloadLink)
} }
return nil, fmt.Errorf("failed extracting user profiles from auth0") return nil, fmt.Errorf("failed extracting user profiles from auth0")
@@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list. // GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with // This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password). // the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email) reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken) body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
err = am.helper.Unmarshal(body, &userResp) err = am.helper.Unmarshal(body, &userResp)
if err != nil { if err != nil {
log.Debugf("Couldn't unmarshal export job response; %v", err) log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err return nil, err
} }
@@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// CreateUser creates a new user in Auth0 Idp and sends an invite // CreateUser creates a new user in Auth0 Idp and sends an invite
func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail) payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err := am.createPostRequest("/api/v2/users", payloadString) req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil { if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError() am.appMetrics.IDPMetrics().CountRequestError()
} }
@@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
defer func() { defer func() {
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing create user response body: %v", err) log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
} }
}() }()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) { if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
@@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
log.Debugf("Couldn't read export job response; %v", err) log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
return nil, err return nil, err
} }
err = am.helper.Unmarshal(body, &createResp) err = am.helper.Unmarshal(body, &createResp)
if err != nil { if err != nil {
log.Debugf("Couldn't unmarshal export job response; %v", err) log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err return nil, err
} }
@@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
return nil, fmt.Errorf("couldn't create user: response %v", resp) return nil, fmt.Errorf("couldn't create user: response %v", resp)
} }
log.Debugf("created user %s in account %s", createResp.ID, accountID) log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
return &createResp, nil return &createResp, nil
} }
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (am *Auth0Manager) InviteUserByID(userID string) error { func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
userVerificationReq := userVerificationJobRequest{ userVerificationReq := userVerificationJobRequest{
UserID: userID, UserID: userID,
} }
@@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
return err return err
} }
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload)) req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
if err != nil { if err != nil {
return err return err
} }
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
log.Debugf("Couldn't get job response %v", err) log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil { if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError() am.appMetrics.IDPMetrics().CountRequestError()
} }
@@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
defer func() { defer func() {
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing invite user response body: %v", err) log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
} }
}() }()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) { if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
@@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
} }
// DeleteUser from Auth0 // DeleteUser from Auth0
func (am *Auth0Manager) DeleteUser(userID string) error { func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
if err != nil { if err != nil {
return err return err
} }
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
log.Debugf("execute delete request: %v", err) log.WithContext(ctx).Debugf("execute delete request: %v", err)
if am.appMetrics != nil { if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError() am.appMetrics.IDPMetrics().CountRequestError()
} }
@@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
defer func() { defer func() {
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
log.Errorf("close delete request body: %v", err) log.WithContext(ctx).Errorf("close delete request body: %v", err)
} }
}() }()
if resp.StatusCode != 204 { if resp.StatusCode != 204 {
@@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
// GetAllConnections returns detailed list of all connections filtered by given params. // GetAllConnections returns detailed list of all connections filtered by given params.
// Note this method is not part of the IDP Manager interface as this is Auth0 specific. // Note this method is not part of the IDP Manager interface as this is Auth0 specific.
func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) { func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
var connections []Connection var connections []Connection
q := make(url.Values) q := make(url.Values)
q.Set("strategy", strings.Join(strategy, ",")) q.Set("strategy", strings.Join(strategy, ","))
req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil) req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
if err != nil { if err != nil {
return connections, err return connections, err
} }
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
log.Debugf("execute get connections request: %v", err) log.WithContext(ctx).Debugf("execute get connections request: %v", err)
if am.appMetrics != nil { if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError() am.appMetrics.IDPMetrics().CountRequestError()
} }
@@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
defer func() { defer func() {
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
log.Errorf("close get connections request body: %v", err) log.WithContext(ctx).Errorf("close get connections request body: %v", err)
} }
}() }()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
@@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
log.Debugf("Couldn't read get connections response; %v", err) log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
return connections, err return connections, err
} }
err = am.helper.Unmarshal(body, &connections) err = am.helper.Unmarshal(body, &connections)
if err != nil { if err != nil {
log.Debugf("Couldn't unmarshal get connection response; %v", err) log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
return connections, err return connections, err
} }
@@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink // If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
defer cancel() defer cancel()
retry := time.NewTicker(10 * time.Second) retry := time.NewTicker(10 * time.Second)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Debugf("Export job status stopped...\n") log.WithContext(ctx).Debugf("Export job status stopped...\n")
return false, "", ctx.Err() return false, "", ctx.Err()
case <-retry.C: case <-retry.C:
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return false, "", err return false, "", err
} }
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken) body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
if err != nil { if err != nil {
return false, "", err return false, "", err
} }
@@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
return false, "", err return false, "", err
} }
log.Debugf("current export job status is %v", status.Status) log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
if status.Status != "completed" { if status.Status != "completed" {
continue continue
@@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
} }
// downloadProfileExport downloads user profiles from auth0 batch job // downloadProfileExport downloads user profiles from auth0 batch job
func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) { func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
body, err := doGetReq(am.httpClient, location, "") body, err := doGetReq(ctx, am.httpClient, location, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
} }
// Boilerplate implementation for Get Requests. // Boilerplate implementation for Get Requests.
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) { func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -945,7 +945,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
defer func() { defer func() {
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
log.Errorf("error while closing body for url %s: %v", url, err) log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
} }
}() }()
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -60,7 +61,7 @@ type mockAuth0Credentials struct {
err error err error
} }
func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) { func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
return mc.jwtToken, mc.err return mc.jwtToken, mc.err
} }
@@ -126,7 +127,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
helper: testCase.helper, helper: testCase.helper,
} }
res, err := creds.requestJWTToken() res, err := creds.requestJWTToken(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -295,7 +296,7 @@ func TestAuth0_Authenticate(t *testing.T) {
creds.jwtToken.expiresInTime = testCase.inputExpireToken creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate() _, err := creds.Authenticate(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -417,7 +418,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
helper: testCase.helper, helper: testCase.helper,
} }
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match") assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")

View File

@@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool {
} }
// requestJWTToken performs request to get jwt token. // requestJWTToken performs request to get jwt token.
func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) { func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{} data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_id", ac.clientConfig.ClientID)
data.Set("username", ac.clientConfig.Username) data.Set("username", ac.clientConfig.Username)
@@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
} }
req.Header.Add("content-type", "application/x-www-form-urlencoded") req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for authentik idp manager") log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
resp, err := ac.httpClient.Do(req) resp, err := ac.httpClient.Do(req)
if err != nil { if err != nil {
@@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
} }
// Authenticate retrieves access token to use the authentik management API. // Authenticate retrieves access token to use the authentik management API.
func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock() ac.mux.Lock()
defer ac.mux.Unlock() defer ac.mux.Unlock()
@@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
return ac.jwtToken, nil return ac.jwtToken, nil
} }
resp, err := ac.requestJWTToken() resp, err := ac.requestJWTToken(ctx)
if err != nil { if err != nil {
return ac.jwtToken, err return ac.jwtToken, err
} }
@@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
// GetUserDataByID requests user data from authentik via ID. // GetUserDataByID requests user data from authentik via ID.
func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
ctx, err := am.authenticationContext() ctx, err := am.authenticationContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers() users, err := am.getAllUsers(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers() users, err := am.getAllUsers(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
} }
// getAllUsers returns all users in a Authentik account. // getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
page := int32(1) page := int32(1)
for { for {
ctx, err := am.authenticationContext() ctx, err := am.authenticationContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
} }
// CreateUser creates a new user in authentik Idp and sends an invitation. // CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
ctx, err := am.authenticationContext() ctx, err := am.authenticationContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (am *AuthentikManager) InviteUserByID(_ string) error { func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Authentik // DeleteUser from Authentik
func (am *AuthentikManager) DeleteUser(userID string) error { func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
ctx, err := am.authenticationContext() ctx, err := am.authenticationContext(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(userID string) error {
return nil return nil
} }
func (am *AuthentikManager) authenticationContext() (context.Context, error) { func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"strings" "strings"
@@ -138,7 +139,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) {
helper: testCase.helper, helper: testCase.helper,
} }
resp, err := creds.requestJWTToken() resp, err := creds.requestJWTToken(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -304,7 +305,7 @@ func TestAuthentikAuthenticate(t *testing.T) {
} }
creds.jwtToken.expiresInTime = testCase.inputExpireToken creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate() _, err := creds.Authenticate(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -110,7 +111,7 @@ func (ac *AzureCredentials) jwtStillValid() bool {
} }
// requestJWTToken performs request to get jwt token. // requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{} data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret) data.Set("client_secret", ac.clientConfig.ClientSecret)
@@ -132,7 +133,7 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
} }
req.Header.Add("content-type", "application/x-www-form-urlencoded") req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for azure idp manager") log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req) resp, err := ac.httpClient.Do(req)
if err != nil { if err != nil {
@@ -184,7 +185,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
} }
// Authenticate retrieves access token to use the azure Management API. // Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate() (JWTToken, error) { func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock() ac.mux.Lock()
defer ac.mux.Unlock() defer ac.mux.Unlock()
@@ -198,7 +199,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
return ac.jwtToken, nil return ac.jwtToken, nil
} }
resp, err := ac.requestJWTToken() resp, err := ac.requestJWTToken(ctx)
if err != nil { if err != nil {
return ac.jwtToken, err return ac.jwtToken, err
} }
@@ -215,16 +216,16 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in azure AD Idp. // CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
q := url.Values{} q := url.Values{}
q.Add("$select", profileFields) q.Add("$select", profileFields)
body, err := am.get("users/"+userID, q) body, err := am.get(ctx, "users/"+userID, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -247,11 +248,11 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{} q := url.Values{}
q.Add("$select", profileFields) q.Add("$select", profileFields)
body, err := am.get("users/"+email, q) body, err := am.get(ctx, "users/"+email, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -273,8 +274,8 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers() users, err := am.getAllUsers(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -293,8 +294,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers() users, err := am.getAllUsers(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -310,19 +311,19 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
} }
// UpdateUserAppMetadata updates user app metadata based on userID. // UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (am *AzureManager) InviteUserByID(_ string) error { func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Azure. // DeleteUser from Azure.
func (am *AzureManager) DeleteUser(userID string) error { func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -335,7 +336,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
log.Debugf("delete idp user %s", userID) log.WithContext(ctx).Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req) resp, err := am.httpClient.Do(req)
if err != nil { if err != nil {
@@ -358,7 +359,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
} }
// getAllUsers returns all users in an Azure AD account. // getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers() ([]*UserData, error) { func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
q := url.Values{} q := url.Values{}
@@ -366,7 +367,7 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
q.Add("$top", "500") q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; { for nextLink := "users"; nextLink != ""; {
body, err := am.get(nextLink, q) body, err := am.get(ctx, nextLink, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -391,8 +392,8 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
} }
// get perform Get requests. // get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@@ -101,7 +102,7 @@ func TestAzureAuthenticate(t *testing.T) {
} }
creds.jwtToken.expiresInTime = testCase.inputExpireToken creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate() _, err := creds.Authenticate(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct {
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
} }
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) { func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil return JWTToken{}, nil
} }
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager. // NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) { func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5 httpTransport.MaxIdleConns = 5
@@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
} }
// Create a new Admin SDK Directory service client // Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey) adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
// GetUserDataByID requests user data from Google Workspace via ID. // GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Do() user, err := gm.usersService.Get(userID).Do()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
users, err := gm.getAllUsers() users, err := gm.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, err
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
users, err := gm.getAllUsers() users, err := gm.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
} }
// CreateUser creates a new user in Google Workspace and sends an invitation. // CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) { func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Do() user, err := gm.usersService.Get(email).Do()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error { func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from GoogleWorkspace. // DeleteUser from GoogleWorkspace.
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
if err := gm.usersService.Delete(userID).Do(); err != nil { if err := gm.usersService.Delete(userID).Do(); err != nil {
return err return err
} }
@@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path. // If that fails, it falls back to using the default Google credentials path.
// It returns the retrieved credentials or an error if unsuccessful. // It returns the retrieved credentials or an error if unsuccessful.
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) { func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
log.Debug("retrieving google credentials from the base64 encoded service account key") log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err) return nil, fmt.Errorf("failed to decode service account key: %w", err)
@@ -239,8 +239,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
return creds, nil return creds, nil
} }
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.Debug("falling back to default google credentials location") log.WithContext(ctx).Debug("falling back to default google credentials location")
creds, err = google.FindDefaultCredentials( creds, err = google.FindDefaultCredentials(
context.Background(), context.Background(),

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@@ -16,14 +17,14 @@ const (
// Manager idp manager interface // Manager idp manager interface
type Manager interface { type Manager interface {
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(accountId string) ([]*UserData, error) GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccounts() (map[string][]*UserData, error) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
InviteUserByID(userID string) error InviteUserByID(ctx context.Context, userID string) error
DeleteUser(userID string) error DeleteUser(ctx context.Context, userID string) error
} }
// ClientConfig defines common client configuration for all IdP manager // ClientConfig defines common client configuration for all IdP manager
@@ -51,7 +52,7 @@ type Config struct {
// ManagerCredentials interface that authenticates using the credential of each type of idp // ManagerCredentials interface that authenticates using the credential of each type of idp
type ManagerCredentials interface { type ManagerCredentials interface {
Authenticate() (JWTToken, error) Authenticate(ctx context.Context) (JWTToken, error)
} }
// ManagerHTTPClient http client interface for API calls // ManagerHTTPClient http client interface for API calls
@@ -91,7 +92,7 @@ type JWTToken struct {
} }
// NewManager returns a new idp manager based on the configuration that it receives // NewManager returns a new idp manager based on the configuration that it receives
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) { func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
if config.ClientConfig != nil { if config.ClientConfig != nil {
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/") config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
} }
@@ -175,7 +176,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"], ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
CustomerID: config.ExtraConfig["CustomerId"], CustomerID: config.ExtraConfig["CustomerId"],
} }
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics) return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
case "jumpcloud": case "jumpcloud":
jumpcloudConfig := JumpCloudClientConfig{ jumpcloudConfig := JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"], APIToken: config.ExtraConfig["ApiToken"],

View File

@@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
} }
// Authenticate retrieves access token to use the JumpCloud user API. // Authenticate retrieves access token to use the JumpCloud user API.
func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) { func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil return JWTToken{}, nil
} }
@@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context {
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
// GetUserDataByID requests user data from JumpCloud via ID. // GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext() authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil) user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
if err != nil { if err != nil {
@@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetada
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) { func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext() authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil { if err != nil {
@@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) { func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
authCtx := jm.authenticationContext() authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil { if err != nil {
@@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
} }
// CreateUser creates a new user in JumpCloud Idp and sends an invitation. // CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) { func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{ searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{ "searchFilter": map[string]interface{}{
"filter": []string{email}, "filter": []string{email},
@@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (jm *JumpCloudManager) InviteUserByID(_ string) error { func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from jumpCloud directory // DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(userID string) error { func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
authCtx := jm.authenticationContext() authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil) _, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -109,7 +110,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool {
} }
// requestJWTToken performs request to get jwt token. // requestJWTToken performs request to get jwt token.
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{} data := url.Values{}
data.Set("client_id", kc.clientConfig.ClientID) data.Set("client_id", kc.clientConfig.ClientID)
data.Set("client_secret", kc.clientConfig.ClientSecret) data.Set("client_secret", kc.clientConfig.ClientSecret)
@@ -122,7 +123,7 @@ func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
} }
req.Header.Add("content-type", "application/x-www-form-urlencoded") req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for keycloak idp manager") log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
resp, err := kc.httpClient.Do(req) resp, err := kc.httpClient.Do(req)
if err != nil { if err != nil {
@@ -174,7 +175,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
} }
// Authenticate retrieves access token to use the keycloak Management API. // Authenticate retrieves access token to use the keycloak Management API.
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
kc.mux.Lock() kc.mux.Lock()
defer kc.mux.Unlock() defer kc.mux.Unlock()
@@ -188,7 +189,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
return kc.jwtToken, nil return kc.jwtToken, nil
} }
resp, err := kc.requestJWTToken() resp, err := kc.requestJWTToken(ctx)
if err != nil { if err != nil {
return kc.jwtToken, err return kc.jwtToken, err
} }
@@ -205,18 +206,18 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in keycloak Idp and sends an invite. // CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{} q := url.Values{}
q.Add("email", email) q.Add("email", email)
q.Add("exact", "true") q.Add("exact", "true")
body, err := km.get("users", q) body, err := km.get(ctx, "users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -240,8 +241,8 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) { func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get("users/"+userID, nil) body, err := km.get(ctx, "users/"+userID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -260,8 +261,8 @@ func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserD
} }
// GetAccount returns all the users for a given account profile. // GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
profiles, err := km.fetchAllUserProfiles() profiles, err := km.fetchAllUserProfiles(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -283,8 +284,8 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
profiles, err := km.fetchAllUserProfiles() profiles, err := km.fetchAllUserProfiles(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -303,19 +304,19 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (km *KeycloakManager) InviteUserByID(_ string) error { func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Keycloak by user ID. // DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(userID string) error { func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := km.credentials.Authenticate() jwtToken, err := km.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -353,8 +354,8 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
return nil return nil
} }
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount() totalUsers, err := km.totalUsersCount(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -362,7 +363,7 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
q := url.Values{} q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers)) q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q) body, err := km.get(ctx, "users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -377,8 +378,8 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
} }
// get perform Get requests. // get perform Get requests.
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := km.credentials.Authenticate() jwtToken, err := km.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -414,8 +415,8 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
// totalUsersCount returns the total count of all user created. // totalUsersCount returns the total count of all user created.
// Used when fetching all registered accounts with pagination. // Used when fetching all registered accounts with pagination.
func (km *KeycloakManager) totalUsersCount() (*int, error) { func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
body, err := km.get("users/count", nil) body, err := km.get(ctx, "users/count", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"strings" "strings"
@@ -128,7 +129,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) {
helper: testCase.helper, helper: testCase.helper,
} }
resp, err := creds.requestJWTToken() resp, err := creds.requestJWTToken(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -294,7 +295,7 @@ func TestKeycloakAuthenticate(t *testing.T) {
} }
creds.jwtToken.expiresInTime = testCase.inputExpireToken creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate() _, err := creds.Authenticate(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -1,77 +1,79 @@
package idp package idp
import "context"
// MockIDP is a mock implementation of the IDP interface // MockIDP is a mock implementation of the IDP interface
type MockIDP struct { type MockIDP struct {
UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error) GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccountFunc func(accountId string) ([]*UserData, error) GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccountsFunc func() (map[string][]*UserData, error) GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error) CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmailFunc func(email string) ([]*UserData, error) GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
InviteUserByIDFunc func(userID string) error InviteUserByIDFunc func(ctx context.Context, userID string) error
DeleteUserFunc func(userID string) error DeleteUserFunc func(ctx context.Context, userID string) error
} }
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method // UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error { func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
if m.UpdateUserAppMetadataFunc != nil { if m.UpdateUserAppMetadataFunc != nil {
return m.UpdateUserAppMetadataFunc(userId, appMetadata) return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
} }
return nil return nil
} }
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method // GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) { func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
if m.GetUserDataByIDFunc != nil { if m.GetUserDataByIDFunc != nil {
return m.GetUserDataByIDFunc(userId, appMetadata) return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
} }
return nil, nil return nil, nil
} }
// GetAccount is a mock implementation of the IDP interface GetAccount method // GetAccount is a mock implementation of the IDP interface GetAccount method
func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) { func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
if m.GetAccountFunc != nil { if m.GetAccountFunc != nil {
return m.GetAccountFunc(accountId) return m.GetAccountFunc(ctx, accountId)
} }
return nil, nil return nil, nil
} }
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method // GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) { func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
if m.GetAllAccountsFunc != nil { if m.GetAllAccountsFunc != nil {
return m.GetAllAccountsFunc() return m.GetAllAccountsFunc(ctx)
} }
return nil, nil return nil, nil
} }
// CreateUser is a mock implementation of the IDP interface CreateUser method // CreateUser is a mock implementation of the IDP interface CreateUser method
func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
if m.CreateUserFunc != nil { if m.CreateUserFunc != nil {
return m.CreateUserFunc(email, name, accountID, invitedByEmail) return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
} }
return nil, nil return nil, nil
} }
// GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method // GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method
func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) { func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
if m.GetUserByEmailFunc != nil { if m.GetUserByEmailFunc != nil {
return m.GetUserByEmailFunc(email) return m.GetUserByEmailFunc(ctx, email)
} }
return nil, nil return nil, nil
} }
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method // InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
func (m *MockIDP) InviteUserByID(userID string) error { func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
if m.InviteUserByIDFunc != nil { if m.InviteUserByIDFunc != nil {
return m.InviteUserByIDFunc(userID) return m.InviteUserByIDFunc(ctx, userID)
} }
return nil return nil
} }
// DeleteUser is a mock implementation of the IDP interface DeleteUser method // DeleteUser is a mock implementation of the IDP interface DeleteUser method
func (m *MockIDP) DeleteUser(userID string) error { func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error {
if m.DeleteUserFunc != nil { if m.DeleteUserFunc != nil {
return m.DeleteUserFunc(userID) return m.DeleteUserFunc(ctx, userID)
} }
return nil return nil
} }

View File

@@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
} }
// Authenticate retrieves access token to use the okta user API. // Authenticate retrieves access token to use the okta user API.
func (oc *OktaCredentials) Authenticate() (JWTToken, error) { func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil return JWTToken{}, nil
} }
// CreateUser creates a new user in okta Idp and sends an invitation. // CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) { func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented") return nil, fmt.Errorf("method CreateUser not implemented")
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), userID) user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email)) user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
users, err := om.getAllUsers() users, err := om.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
users, err := om.getAllUsers() users, err := om.getAllUsers()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) {
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (om *OktaManager) InviteUserByID(_ string) error { func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Okta // DeleteUser from Okta
func (om *OktaManager) DeleteUser(userID string) error { func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil) resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
if err != nil { if err != nil {
return err return err

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -149,7 +150,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool {
} }
// requestJWTToken performs request to get jwt token. // requestJWTToken performs request to get jwt token.
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) { func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{} data := url.Values{}
data.Set("client_id", zc.clientConfig.ClientID) data.Set("client_id", zc.clientConfig.ClientID)
data.Set("client_secret", zc.clientConfig.ClientSecret) data.Set("client_secret", zc.clientConfig.ClientSecret)
@@ -163,7 +164,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
} }
req.Header.Add("content-type", "application/x-www-form-urlencoded") req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for zitadel idp manager") log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager")
resp, err := zc.httpClient.Do(req) resp, err := zc.httpClient.Do(req)
if err != nil { if err != nil {
@@ -215,7 +216,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
} }
// Authenticate retrieves access token to use the Zitadel Management API. // Authenticate retrieves access token to use the Zitadel Management API.
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
zc.mux.Lock() zc.mux.Lock()
defer zc.mux.Unlock() defer zc.mux.Unlock()
@@ -229,7 +230,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
return zc.jwtToken, nil return zc.jwtToken, nil
} }
resp, err := zc.requestJWTToken() resp, err := zc.requestJWTToken(ctx)
if err != nil { if err != nil {
return zc.jwtToken, err return zc.jwtToken, err
} }
@@ -246,7 +247,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel. // CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.SplitN(name, " ", 2) firstLast := strings.SplitN(name, " ", 2)
var addUser = map[string]any{ var addUser = map[string]any{
@@ -269,7 +270,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
return nil, err return nil, err
} }
body, err := zm.post("users/human/_import", string(payload)) body, err := zm.post(ctx, "users/human/_import", string(payload))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -300,7 +301,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
searchByEmail := zitadelAttributes{ searchByEmail := zitadelAttributes{
"queries": { "queries": {
{ {
@@ -316,7 +317,7 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
return nil, err return nil, err
} }
body, err := zm.post("users/_search", string(payload)) body, err := zm.post(ctx, "users/_search", string(payload))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -340,8 +341,8 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// GetUserDataByID requests user data from zitadel via ID. // GetUserDataByID requests user data from zitadel via ID.
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get("users/"+userID, nil) body, err := zm.get(ctx, "users/"+userID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -363,8 +364,8 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
body, err := zm.post("users/_search", "") body, err := zm.post(ctx, "users/_search", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -392,8 +393,8 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
body, err := zm.post("users/_search", "") body, err := zm.post(ctx, "users/_search", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -419,7 +420,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded. // Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil return nil
} }
@@ -429,7 +430,7 @@ type inviteUserRequest struct {
// InviteUserByID resend invitations to users who haven't activated, // InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period. // their accounts prior to the expiration period.
func (zm *ZitadelManager) InviteUserByID(userID string) error { func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error {
inviteUser := inviteUserRequest{ inviteUser := inviteUserRequest{
Email: userID, Email: userID,
} }
@@ -440,14 +441,14 @@ func (zm *ZitadelManager) InviteUserByID(userID string) error {
} }
// don't care about the body in the response // don't care about the body in the response
_, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload)) _, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
return err return err
} }
// DeleteUser from Zitadel // DeleteUser from Zitadel
func (zm *ZitadelManager) DeleteUser(userID string) error { func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error {
resource := fmt.Sprintf("users/%s", userID) resource := fmt.Sprintf("users/%s", userID)
if err := zm.delete(resource); err != nil { if err := zm.delete(ctx, resource); err != nil {
return err return err
} }
@@ -459,8 +460,8 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
} }
// post perform Post requests. // post perform Post requests.
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate() jwtToken, err := zm.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -495,8 +496,8 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
} }
// delete perform Delete requests. // delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error { func (zm *ZitadelManager) delete(ctx context.Context, resource string) error {
jwtToken, err := zm.credentials.Authenticate() jwtToken, err := zm.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -531,8 +532,8 @@ func (zm *ZitadelManager) delete(resource string) error {
} }
// get perform Get requests. // get perform Get requests.
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate() jwtToken, err := zm.credentials.Authenticate(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,6 +1,7 @@
package idp package idp
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"strings" "strings"
@@ -108,7 +109,7 @@ func TestZitadelRequestJWTToken(t *testing.T) {
helper: testCase.helper, helper: testCase.helper,
} }
resp, err := creds.requestJWTToken() resp, err := creds.requestJWTToken(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@@ -274,7 +275,7 @@ func TestZitadelAuthenticate(t *testing.T) {
} }
creds.jwtToken.expiresInTime = testCase.inputExpireToken creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate() _, err := creds.Authenticate(context.Background())
if err != nil { if err != nil {
if testCase.expectedFuncExitErrDiff != nil { if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@@ -1,9 +1,10 @@
package server package server
import ( import (
"context"
"errors" "errors"
"github.com/google/martian/v3/log" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
) )
@@ -19,22 +20,22 @@ import (
// //
// Returns: // Returns:
// - error: An error if any occurred during the process, otherwise returns nil // - error: An error if any occurred during the process, otherwise returns nil
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
ok, err := am.GroupValidation(accountID, groups) ok, err := am.GroupValidation(ctx, accountID, groups)
if err != nil { if err != nil {
log.Debugf("error validating groups: %s", err.Error()) log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
return err return err
} }
if !ok { if !ok {
log.Debugf("invalid groups") log.WithContext(ctx).Debugf("invalid groups")
return errors.New("invalid groups") return errors.New("invalid groups")
} }
unlock := am.Store.AcquireAccountWriteLock(accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
a, err := am.Store.GetAccountByUser(userID) a, err := am.Store.GetAccountByUser(ctx, userID)
if err != nil { if err != nil {
return err return err
} }
@@ -48,14 +49,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
a.Settings.Extra = extra a.Settings.Extra = extra
} }
extra.IntegratedValidatorGroups = groups extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(a) return am.Store.SaveAccount(ctx, a)
} }
func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
if len(groups) == 0 { if len(groups) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(accountId) accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -1,6 +1,8 @@
package integrated_validator package integrated_validator
import ( import (
"context"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -8,12 +10,12 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies // IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface { type IntegratedValidator interface {
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(accountID, peerID string) error PeerDeleted(ctx context.Context, accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string)) SetPeerInvalidationListener(fn func(accountID string))
Stop() Stop(ctx context.Context)
} }

Some files were not shown because too many files have changed in this diff Show More