mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 03:06:38 +00:00
Compare commits
127 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0c3818e06 | ||
|
|
6922826919 | ||
|
|
56a1a75e3f | ||
|
|
d9402168ad | ||
|
|
dbdef04b9e | ||
|
|
29cbfe8467 | ||
|
|
6ce8643368 | ||
|
|
07d1ad35fc | ||
|
|
ef6cd36f1a | ||
|
|
c1c71b6d39 | ||
|
|
0480507a10 | ||
|
|
34ac4e4b5a | ||
|
|
52ff9d9602 | ||
|
|
1b73fae46e | ||
|
|
d897365abc | ||
|
|
f37aa2cc9d | ||
|
|
5343bee7b2 | ||
|
|
870e29db63 | ||
|
|
08e9b05d51 | ||
|
|
3581648071 | ||
|
|
2a51609436 | ||
|
|
83457f8b99 | ||
|
|
b45284f086 | ||
|
|
e9016aecea | ||
|
|
23b5d45b68 | ||
|
|
0e5dc9d412 | ||
|
|
91f7ee6a3c | ||
|
|
7c6b85b4cb | ||
|
|
08c9107c61 | ||
|
|
81d83245e1 | ||
|
|
af2b427751 | ||
|
|
f61ebdb3bc | ||
|
|
de7384e8ea | ||
|
|
75c1be69cf | ||
|
|
424ae28de9 | ||
|
|
d4a800edd5 | ||
|
|
dd9917f1a8 | ||
|
|
8df8c1012f | ||
|
|
bfa5c21d2d | ||
|
|
b1247a14ba | ||
|
|
f595057a0b | ||
|
|
089d442fb2 | ||
|
|
04a3765391 | ||
|
|
d24d8328f9 | ||
|
|
4f63996ae8 | ||
|
|
bdf2994e97 | ||
|
|
6d654acbad | ||
|
|
3e43298471 | ||
|
|
0ad2590974 | ||
|
|
9d11257b1a | ||
|
|
4ee1635baa | ||
|
|
75feb0da8b | ||
|
|
87376afd13 | ||
|
|
b76d9e8e9e | ||
|
|
e71383dcb9 | ||
|
|
e002a2e6e8 | ||
|
|
6127a01196 | ||
|
|
de27d6df36 | ||
|
|
3c535cdd2b | ||
|
|
0f050e5fe1 | ||
|
|
0f7c7f1da2 | ||
|
|
b56f61bf1b | ||
|
|
64f111923e | ||
|
|
122a89c02b | ||
|
|
c6cceba381 | ||
|
|
6c0cdb6ed1 | ||
|
|
84354951d3 | ||
|
|
55957a1960 | ||
|
|
df82a45d99 | ||
|
|
9424b88db2 | ||
|
|
609654eee7 | ||
|
|
b604c66140 | ||
|
|
ea4d13e96d | ||
|
|
87148c503f | ||
|
|
0cd36baf67 | ||
|
|
06980e7fa0 | ||
|
|
1ce4ee0cef | ||
|
|
f367925496 | ||
|
|
616b19c064 | ||
|
|
af27aaf9af | ||
|
|
35287f8241 | ||
|
|
07b220d91b | ||
|
|
41cd4952f1 | ||
|
|
f16f0c7831 | ||
|
|
aa07b3b87b | ||
|
|
2bef214cc0 | ||
|
|
cfb2d82352 | ||
|
|
684501fd35 | ||
|
|
0492c1724a | ||
|
|
6f436e57b5 | ||
|
|
a0d28f9851 | ||
|
|
cdd27a9fe5 | ||
|
|
5523040acd | ||
|
|
670446d42e | ||
|
|
5bed6777d5 | ||
|
|
a0482ebc7b | ||
|
|
2a89d6e47a | ||
|
|
24f932b2ce | ||
|
|
c03435061c | ||
|
|
8e948739f1 | ||
|
|
9b53cad752 | ||
|
|
802a18167c | ||
|
|
e9108ffe6c | ||
|
|
e806d9de38 | ||
|
|
daa8380df9 | ||
|
|
4785f23fc4 | ||
|
|
1d4cfb83e7 | ||
|
|
207fa059d2 | ||
|
|
cbcdad7814 | ||
|
|
701c13807a | ||
|
|
99f8dc7748 | ||
|
|
f1de8e6eb0 | ||
|
|
b2a10780af | ||
|
|
43ae79d848 | ||
|
|
e520b64c6d | ||
|
|
92c91bbdd8 | ||
|
|
adf494e1ac | ||
|
|
2158461121 | ||
|
|
0cd4b601c3 | ||
|
|
ee1cec47b3 | ||
|
|
efb0edfc4c | ||
|
|
20f59ddecb | ||
|
|
2f34e984b0 | ||
|
|
d5b52e86b6 | ||
|
|
cad2fe1f39 | ||
|
|
fcd2c15a37 | ||
|
|
ebda0fc538 |
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,17 +37,22 @@ If yes, which one?
|
|||||||
|
|
||||||
**Debug output**
|
**Debug output**
|
||||||
|
|
||||||
To help us resolve the problem, please attach the following debug output
|
To help us resolve the problem, please attach the following anonymized status output
|
||||||
|
|
||||||
netbird status -dA
|
netbird status -dA
|
||||||
|
|
||||||
As well as the file created by
|
Create and upload a debug bundle, and share the returned file key:
|
||||||
|
|
||||||
|
netbird debug for 1m -AS -U
|
||||||
|
|
||||||
|
*Uploaded files are automatically deleted after 30 days.*
|
||||||
|
|
||||||
|
|
||||||
|
Alternatively, create the file only and attach it here manually:
|
||||||
|
|
||||||
netbird debug for 1m -AS
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
We advise reviewing the anonymized output for any remaining personal information.
|
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
**Have you tried these troubleshooting steps?**
|
**Have you tried these troubleshooting steps?**
|
||||||
|
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||||
- [ ] Checked for newer NetBird versions
|
- [ ] Checked for newer NetBird versions
|
||||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
- [ ] Restarted the NetBird client
|
- [ ] Restarted the NetBird client
|
||||||
- [ ] Disabled other VPN software
|
- [ ] Disabled other VPN software
|
||||||
- [ ] Checked firewall settings
|
- [ ] Checked firewall settings
|
||||||
|
|
||||||
|
|||||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -13,3 +13,5 @@
|
|||||||
- [ ] It is a refactor
|
- [ ] It is a refactor
|
||||||
- [ ] Created tests that fail without the change (if possible)
|
- [ ] Created tests that fail without the change (if possible)
|
||||||
- [ ] Extended the README / documentation, if necessary
|
- [ ] Extended the README / documentation, if necessary
|
||||||
|
|
||||||
|
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||||
|
|||||||
8
.github/workflows/golang-test-linux.yml
vendored
8
.github/workflows/golang-test-linux.yml
vendored
@@ -223,6 +223,10 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -269,6 +273,10 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
|||||||
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
9
.github/workflows/release.yml
vendored
9
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.18"
|
SIGN_PIPE_VER: "v0.0.20"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -65,6 +65,13 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
- name: Log in to the GitHub container registry
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
|
||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
|
|
||||||
|
|||||||
@@ -134,6 +134,7 @@ jobs:
|
|||||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||||
|
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
set -x
|
set -x
|
||||||
@@ -172,13 +173,15 @@ jobs:
|
|||||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||||
# check relay values
|
# check relay values
|
||||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||||
grep '33445:33445' docker-compose.yml
|
grep '33445:33445' docker-compose.yml
|
||||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
grep DisablePromptLogin management.json | grep 'true'
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
|
grep LoginFlag management.json | grep 0
|
||||||
|
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|||||||
136
.goreleaser.yaml
136
.goreleaser.yaml
@@ -149,6 +149,7 @@ nfpms:
|
|||||||
dockers:
|
dockers:
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -164,6 +165,7 @@ dockers:
|
|||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -175,10 +177,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm
|
- netbirdio/netbird:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -191,11 +194,12 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -207,9 +211,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -221,9 +227,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -236,10 +244,12 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-amd64
|
- netbirdio/relay:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -251,10 +261,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -266,10 +277,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm
|
- netbirdio/relay:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -282,10 +294,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-amd64
|
- netbirdio/signal:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -297,10 +310,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -312,10 +326,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-arm
|
- netbirdio/signal:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -328,10 +343,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-amd64
|
- netbirdio/management:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -343,10 +359,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-arm64v8
|
- netbirdio/management:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -358,10 +375,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-arm
|
- netbirdio/management:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -374,10 +392,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -389,10 +408,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -404,11 +424,12 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm
|
- netbirdio/management:{{ .Version }}-debug-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -421,10 +442,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-upload
|
- netbird-upload
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -436,10 +458,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-upload
|
- netbird-upload
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -451,10 +474,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-upload
|
- netbird-upload
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -467,7 +491,7 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
@@ -546,6 +570,84 @@ docker_manifests:
|
|||||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/relay:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/signal:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/management:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/management:debug-latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/upload:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
<a href="https://docs.netbird.io/slack-url">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
@@ -29,7 +29,7 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
FROM alpine:3.21.3
|
FROM alpine:3.21.3
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||||
|
|
||||||
|
ARG NETBIRD_BINARY=netbird
|
||||||
|
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||||
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
COPY netbird /usr/local/bin/netbird
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
FROM alpine:3.21.0
|
FROM alpine:3.21.0
|
||||||
|
|
||||||
COPY netbird /usr/local/bin/netbird
|
ARG NETBIRD_BINARY=netbird
|
||||||
|
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates \
|
RUN apk add --no-cache ca-certificates \
|
||||||
&& adduser -D -h /var/lib/netbird netbird
|
&& adduser -D -h /var/lib/netbird netbird
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ type Client struct {
|
|||||||
deviceName string
|
deviceName string
|
||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
|
connectClient *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -106,8 +108,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -132,8 +134,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -174,6 +176,53 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Networks() *NetworkArray {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
log.Error("not connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := c.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
log.Error("could not get engine")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeManager := engine.GetRouteManager()
|
||||||
|
if routeManager == nil {
|
||||||
|
log.Error("could not get route manager")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
networkArray := &NetworkArray{
|
||||||
|
items: make([]Network, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||||
|
if len(routes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if routes[0].IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
network := Network{
|
||||||
|
Name: string(id),
|
||||||
|
Network: routes[0].Network.String(),
|
||||||
|
Peer: peer.FQDN,
|
||||||
|
Status: peer.ConnStatus.String(),
|
||||||
|
}
|
||||||
|
networkArray.Add(network)
|
||||||
|
}
|
||||||
|
return networkArray
|
||||||
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
dnsServer, err := dns.GetServerDns()
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
|||||||
27
client/android/networks.go
Normal file
27
client/android/networks.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
type Network struct {
|
||||||
|
Name string
|
||||||
|
Network string
|
||||||
|
Peer string
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkArray struct {
|
||||||
|
items []Network
|
||||||
|
}
|
||||||
|
|
||||||
|
func (array *NetworkArray) Add(s Network) *NetworkArray {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
func (array *NetworkArray) Get(i int) *Network {
|
||||||
|
return &array.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (array *NetworkArray) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
@@ -7,30 +7,23 @@ type PeerInfo struct {
|
|||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
// PeerInfoArray is a wrapper of []PeerInfo
|
||||||
type PeerInfoCollection interface {
|
|
||||||
Add(s string) PeerInfoCollection
|
|
||||||
Get(i int) string
|
|
||||||
Size() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeerInfoArray is the implementation of the PeerInfoCollection
|
|
||||||
type PeerInfoArray struct {
|
type PeerInfoArray struct {
|
||||||
items []PeerInfo
|
items []PeerInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new PeerInfo to the collection
|
// Add new PeerInfo to the collection
|
||||||
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
|
||||||
array.items = append(array.items, s)
|
array.items = append(array.items, s)
|
||||||
return array
|
return array
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get return an element of the collection
|
// Get return an element of the collection
|
||||||
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
func (array *PeerInfoArray) Get(i int) *PeerInfo {
|
||||||
return &array.items[i]
|
return &array.items[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size return with the size of the collection
|
// Size return with the size of the collection
|
||||||
func (array PeerInfoArray) Size() int {
|
func (array *PeerInfoArray) Size() int {
|
||||||
return len(array.items)
|
return len(array.items)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Preferences export a subset of the internal config for gomobile
|
// Preferences exports a subset of the internal config for gomobile
|
||||||
type Preferences struct {
|
type Preferences struct {
|
||||||
configInput internal.ConfigInput
|
configInput internal.ConfigInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPreferences create new Preferences instance
|
// NewPreferences creates a new Preferences instance
|
||||||
func NewPreferences(configPath string) *Preferences {
|
func NewPreferences(configPath string) *Preferences {
|
||||||
ci := internal.ConfigInput{
|
ci := internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
|||||||
return &Preferences{ci}
|
return &Preferences{ci}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetManagementURL read url from config file
|
// GetManagementURL reads URL from config file
|
||||||
func (p *Preferences) GetManagementURL() (string, error) {
|
func (p *Preferences) GetManagementURL() (string, error) {
|
||||||
if p.configInput.ManagementURL != "" {
|
if p.configInput.ManagementURL != "" {
|
||||||
return p.configInput.ManagementURL, nil
|
return p.configInput.ManagementURL, nil
|
||||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
|||||||
return cfg.ManagementURL.String(), err
|
return cfg.ManagementURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetManagementURL store the given url and wait for commit
|
// SetManagementURL stores the given URL and waits for commit
|
||||||
func (p *Preferences) SetManagementURL(url string) {
|
func (p *Preferences) SetManagementURL(url string) {
|
||||||
p.configInput.ManagementURL = url
|
p.configInput.ManagementURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdminURL read url from config file
|
// GetAdminURL reads URL from config file
|
||||||
func (p *Preferences) GetAdminURL() (string, error) {
|
func (p *Preferences) GetAdminURL() (string, error) {
|
||||||
if p.configInput.AdminURL != "" {
|
if p.configInput.AdminURL != "" {
|
||||||
return p.configInput.AdminURL, nil
|
return p.configInput.AdminURL, nil
|
||||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
|||||||
return cfg.AdminURL.String(), err
|
return cfg.AdminURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAdminURL store the given url and wait for commit
|
// SetAdminURL stores the given URL and waits for commit
|
||||||
func (p *Preferences) SetAdminURL(url string) {
|
func (p *Preferences) SetAdminURL(url string) {
|
||||||
p.configInput.AdminURL = url
|
p.configInput.AdminURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreSharedKey read preshared key from config file
|
// GetPreSharedKey reads pre-shared key from config file
|
||||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||||
if p.configInput.PreSharedKey != nil {
|
if p.configInput.PreSharedKey != nil {
|
||||||
return *p.configInput.PreSharedKey, nil
|
return *p.configInput.PreSharedKey, nil
|
||||||
@@ -66,12 +66,160 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
|||||||
return cfg.PreSharedKey, err
|
return cfg.PreSharedKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPreSharedKey store the given key and wait for commit
|
// SetPreSharedKey stores the given key and waits for commit
|
||||||
func (p *Preferences) SetPreSharedKey(key string) {
|
func (p *Preferences) SetPreSharedKey(key string) {
|
||||||
p.configInput.PreSharedKey = &key
|
p.configInput.PreSharedKey = &key
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit write out the changes into config file
|
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||||
|
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||||
|
p.configInput.RosenpassEnabled = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||||
|
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||||
|
if p.configInput.RosenpassEnabled != nil {
|
||||||
|
return *p.configInput.RosenpassEnabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.RosenpassEnabled, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||||
|
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||||
|
p.configInput.RosenpassPermissive = &permissive
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||||
|
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||||
|
if p.configInput.RosenpassPermissive != nil {
|
||||||
|
return *p.configInput.RosenpassPermissive, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.RosenpassPermissive, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||||
|
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||||
|
if p.configInput.DisableClientRoutes != nil {
|
||||||
|
return *p.configInput.DisableClientRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableClientRoutes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableClientRoutes stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||||
|
p.configInput.DisableClientRoutes = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||||
|
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||||
|
if p.configInput.DisableServerRoutes != nil {
|
||||||
|
return *p.configInput.DisableServerRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableServerRoutes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableServerRoutes stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||||
|
p.configInput.DisableServerRoutes = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableDNS reads disable DNS setting from config file
|
||||||
|
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||||
|
if p.configInput.DisableDNS != nil {
|
||||||
|
return *p.configInput.DisableDNS, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableDNS, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableDNS stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||||
|
p.configInput.DisableDNS = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableFirewall reads disable firewall setting from config file
|
||||||
|
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||||
|
if p.configInput.DisableFirewall != nil {
|
||||||
|
return *p.configInput.DisableFirewall, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableFirewall, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableFirewall stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||||
|
p.configInput.DisableFirewall = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||||
|
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||||
|
if p.configInput.ServerSSHAllowed != nil {
|
||||||
|
return *p.configInput.ServerSSHAllowed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.ServerSSHAllowed == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.ServerSSHAllowed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetServerSSHAllowed stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||||
|
p.configInput.ServerSSHAllowed = &allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockInbound reads block inbound setting from config file
|
||||||
|
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||||
|
if p.configInput.BlockInbound != nil {
|
||||||
|
return *p.configInput.BlockInbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.BlockInbound, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBlockInbound stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetBlockInbound(block bool) {
|
||||||
|
p.configInput.BlockInbound = &block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit writes out the changes to the config file
|
||||||
func (p *Preferences) Commit() error {
|
func (p *Preferences) Commit() error {
|
||||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
|||||||
return a.ipAnonymizer[ip]
|
return a.ipAnonymizer[ip]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||||
|
// Convert IP to netip.Addr
|
||||||
|
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||||
|
if !ok {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
anonIP := a.AnonymizeIP(ip)
|
||||||
|
|
||||||
|
return net.UDPAddr{
|
||||||
|
IP: anonIP.AsSlice(),
|
||||||
|
Port: addr.Port,
|
||||||
|
Zone: addr.Zone,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||||
|
|||||||
210
client/cmd/flags.go
Normal file
210
client/cmd/flags.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SharedFlags contains all configuration flags that are common between up and set commands
|
||||||
|
type SharedFlags struct {
|
||||||
|
// Network configuration
|
||||||
|
InterfaceName string
|
||||||
|
WireguardPort uint16
|
||||||
|
NATExternalIPs []string
|
||||||
|
CustomDNSAddress string
|
||||||
|
ExtraIFaceBlackList []string
|
||||||
|
DNSLabels []string
|
||||||
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
|
// Feature flags
|
||||||
|
RosenpassEnabled bool
|
||||||
|
RosenpassPermissive bool
|
||||||
|
ServerSSHAllowed bool
|
||||||
|
AutoConnectDisabled bool
|
||||||
|
NetworkMonitor bool
|
||||||
|
LazyConnEnabled bool
|
||||||
|
|
||||||
|
// System flags
|
||||||
|
DisableClientRoutes bool
|
||||||
|
DisableServerRoutes bool
|
||||||
|
DisableDNS bool
|
||||||
|
DisableFirewall bool
|
||||||
|
BlockLANAccess bool
|
||||||
|
BlockInbound bool
|
||||||
|
|
||||||
|
// Login-specific (only for up command)
|
||||||
|
NoBrowser bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSharedFlags adds all shared configuration flags to the given command
|
||||||
|
func AddSharedFlags(cmd *cobra.Command, flags *SharedFlags) {
|
||||||
|
// Network configuration flags
|
||||||
|
cmd.PersistentFlags().StringVar(&flags.InterfaceName, interfaceNameFlag, iface.WgInterfaceDefault,
|
||||||
|
"Wireguard interface name")
|
||||||
|
cmd.PersistentFlags().Uint16Var(&flags.WireguardPort, wireguardPortFlag, iface.DefaultWgPort,
|
||||||
|
"Wireguard interface listening port")
|
||||||
|
cmd.PersistentFlags().StringSliceVar(&flags.NATExternalIPs, externalIPMapFlag, nil,
|
||||||
|
`Sets external IPs maps between local addresses and interfaces. `+
|
||||||
|
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
||||||
|
`An empty string "" clears the previous configuration. `+
|
||||||
|
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
||||||
|
`or --external-ip-map ""`)
|
||||||
|
cmd.PersistentFlags().StringVar(&flags.CustomDNSAddress, dnsResolverAddress, "",
|
||||||
|
`Sets a custom address for NetBird's local DNS resolver. `+
|
||||||
|
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
||||||
|
`An empty string "" clears the previous configuration. `+
|
||||||
|
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`)
|
||||||
|
cmd.PersistentFlags().StringSliceVar(&flags.ExtraIFaceBlackList, extraIFaceBlackListFlag, nil,
|
||||||
|
"Extra list of default interfaces to ignore for listening")
|
||||||
|
cmd.PersistentFlags().StringSliceVar(&flags.DNSLabels, dnsLabelsFlag, nil,
|
||||||
|
`Sets DNS labels. `+
|
||||||
|
`You can specify a comma-separated list of up to 32 labels. `+
|
||||||
|
`An empty string "" clears the previous configuration. `+
|
||||||
|
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||||
|
`or --extra-dns-labels ""`)
|
||||||
|
cmd.PersistentFlags().DurationVar(&flags.DNSRouteInterval, dnsRouteIntervalFlag, time.Minute,
|
||||||
|
"DNS route update interval")
|
||||||
|
|
||||||
|
// Feature flags
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.RosenpassEnabled, enableRosenpassFlag, false,
|
||||||
|
"[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.RosenpassPermissive, rosenpassPermissiveFlag, false,
|
||||||
|
"[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.ServerSSHAllowed, serverSSHAllowedFlag, false,
|
||||||
|
"Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.AutoConnectDisabled, disableAutoConnectFlag, false,
|
||||||
|
"Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
|
cmd.PersistentFlags().BoolVarP(&flags.NetworkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||||
|
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||||
|
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`)
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.LazyConnEnabled, enableLazyConnectionFlag, false,
|
||||||
|
"[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||||
|
|
||||||
|
// System flags (from system.go)
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.DisableClientRoutes, disableClientRoutesFlag, false,
|
||||||
|
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.DisableServerRoutes, disableServerRoutesFlag, false,
|
||||||
|
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.DisableDNS, disableDNSFlag, false,
|
||||||
|
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.DisableFirewall, disableFirewallFlag, false,
|
||||||
|
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.BlockLANAccess, blockLANAccessFlag, false,
|
||||||
|
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.BlockInbound, blockInboundFlag, false,
|
||||||
|
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||||
|
"This overrides any policies received from the management service.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUpOnlyFlags adds flags that are specific to the up command
|
||||||
|
func AddUpOnlyFlags(cmd *cobra.Command, flags *SharedFlags) {
|
||||||
|
cmd.PersistentFlags().BoolVar(&flags.NoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildConfigInput creates an internal.ConfigInput from SharedFlags with Changed() checks
|
||||||
|
func BuildConfigInput(cmd *cobra.Command, flags *SharedFlags, customDNSAddressConverted []byte) (*internal.ConfigInput, error) {
|
||||||
|
ic := internal.ConfigInput{
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
AdminURL: adminURL,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle PreSharedKey from root command
|
||||||
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
|
ic.PreSharedKey = &preSharedKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
|
ic.RosenpassEnabled = &flags.RosenpassEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||||
|
ic.RosenpassPermissive = &flags.RosenpassPermissive
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
|
ic.ServerSSHAllowed = &flags.ServerSSHAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
|
if err := parseInterfaceName(flags.InterfaceName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ic.InterfaceName = &flags.InterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(wireguardPortFlag).Changed {
|
||||||
|
p := int(flags.WireguardPort)
|
||||||
|
ic.WireguardPort = &p
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
ic.NetworkMonitor = &flags.NetworkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
ic.DisableAutoConnect = &flags.AutoConnectDisabled
|
||||||
|
|
||||||
|
if flags.AutoConnectDisabled {
|
||||||
|
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
||||||
|
} else {
|
||||||
|
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||||
|
ic.DNSRouteInterval = &flags.DNSRouteInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
ic.DisableClientRoutes = &flags.DisableClientRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
ic.DisableServerRoutes = &flags.DisableServerRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
ic.DisableDNS = &flags.DisableDNS
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
ic.DisableFirewall = &flags.DisableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
ic.BlockLANAccess = &flags.BlockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
|
ic.BlockInbound = &flags.BlockInbound
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
|
ic.LazyConnectionEnabled = &flags.LazyConnEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(externalIPMapFlag).Changed {
|
||||||
|
ic.NATExternalIPs = flags.NATExternalIPs
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(extraIFaceBlackListFlag).Changed {
|
||||||
|
ic.ExtraIFaceBlackList = flags.ExtraIFaceBlackList
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||||
|
var err error
|
||||||
|
ic.DNSLabels, err = domain.FromStringList(flags.DNSLabels)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid DNS labels: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ic, nil
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -100,7 +101,7 @@ var loginCmd = &cobra.Command{
|
|||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
DnsLabels: dnsLabelsReq,
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
@@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isLinuxRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ const (
|
|||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
blockLANAccessFlag = "block-lan-access"
|
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||||
uploadBundle = "upload-bundle"
|
uploadBundle = "upload-bundle"
|
||||||
uploadBundleURL = "upload-bundle-url"
|
uploadBundleURL = "upload-bundle-url"
|
||||||
)
|
)
|
||||||
@@ -77,9 +77,9 @@ var (
|
|||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
debugSystemInfoFlag bool
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
blockLANAccess bool
|
|
||||||
debugUploadBundle bool
|
debugUploadBundle bool
|
||||||
debugUploadBundleURL string
|
debugUploadBundleURL string
|
||||||
|
lazyConnEnabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -149,6 +149,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(loginCmd)
|
rootCmd.AddCommand(loginCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
|
rootCmd.AddCommand(setCmd)
|
||||||
rootCmd.AddCommand(networksCMD)
|
rootCmd.AddCommand(networksCMD)
|
||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
@@ -167,23 +168,6 @@ func init() {
|
|||||||
debugCmd.AddCommand(forCmd)
|
debugCmd.AddCommand(forCmd)
|
||||||
debugCmd.AddCommand(persistenceCmd)
|
debugCmd.AddCommand(persistenceCmd)
|
||||||
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
|
||||||
`Sets external IPs maps between local addresses and interfaces.`+
|
|
||||||
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
|
|
||||||
`An empty string "" clears the previous configuration. `+
|
|
||||||
`E.g. --external-ip-map 12.34.56.78/10.0.0.1 or --external-ip-map 12.34.56.200,12.34.56.78/10.0.0.1,12.34.56.80/eth1 `+
|
|
||||||
`or --external-ip-map ""`,
|
|
||||||
)
|
|
||||||
upCmd.PersistentFlags().StringVar(&customDNSAddress, dnsResolverAddress, "",
|
|
||||||
`Sets a custom address for NetBird's local DNS resolver. `+
|
|
||||||
`If set, the agent won't attempt to discover the best ip and port to listen on. `+
|
|
||||||
`An empty string "" clears the previous configuration. `+
|
|
||||||
`E.g. --dns-resolver-address 127.0.0.1:5053 or --dns-resolver-address ""`,
|
|
||||||
)
|
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
|
||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
|
||||||
|
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSVCConfig() *service.Config {
|
func newSVCConfig() *service.Config {
|
||||||
return &service.Config{
|
config := &service.Config{
|
||||||
Name: serviceName,
|
Name: serviceName,
|
||||||
DisplayName: "Netbird",
|
DisplayName: "Netbird",
|
||||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
Description: "Netbird mesh network client",
|
||||||
Option: make(service.KeyValue),
|
Option: make(service.KeyValue),
|
||||||
|
EnvVars: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
|||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if logFile != "console" {
|
if logFile != "" {
|
||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
161
client/cmd/set.go
Normal file
161
client/cmd/set.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
setFlags = &SharedFlags{}
|
||||||
|
|
||||||
|
setCmd = &cobra.Command{
|
||||||
|
Use: "set",
|
||||||
|
Short: "Update NetBird client configuration",
|
||||||
|
Long: `Update NetBird client configuration without connecting. Uses the same flags as 'netbird up' but only updates the configuration file.`,
|
||||||
|
RunE: setFunc,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Add all shared flags to the set command
|
||||||
|
AddSharedFlags(setCmd, setFlags)
|
||||||
|
// Note: We don't add up-only flags like --no-browser to set command
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFunc(cmd *cobra.Command, _ []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
SetFlagsFromEnvVars(cmd)
|
||||||
|
|
||||||
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
// Validate inputs (reuse validation logic from up.go)
|
||||||
|
if err := validateNATExternalIPs(setFlags.NATExternalIPs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||||
|
if _, err := validateDnsLabels(setFlags.DNSLabels); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var customDNSAddressConverted []byte
|
||||||
|
if cmd.Flag(dnsResolverAddress).Changed {
|
||||||
|
var err error
|
||||||
|
customDNSAddressConverted, err = parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to daemon
|
||||||
|
ctx := cmd.Context()
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
|
fmt.Printf("Warning: failed to close connection: %v\n", closeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
req := &proto.SetConfigRequest{}
|
||||||
|
|
||||||
|
// Set fields based on changed flags
|
||||||
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
|
req.RosenpassEnabled = &setFlags.RosenpassEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
||||||
|
req.RosenpassPermissive = &setFlags.RosenpassPermissive
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
|
req.ServerSSHAllowed = &setFlags.ServerSSHAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
req.DisableAutoConnect = &setFlags.AutoConnectDisabled
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
req.NetworkMonitor = &setFlags.NetworkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
|
if err := parseInterfaceName(setFlags.InterfaceName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.InterfaceName = &setFlags.InterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(wireguardPortFlag).Changed {
|
||||||
|
port := int64(setFlags.WireguardPort)
|
||||||
|
req.WireguardPort = &port
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsResolverAddress).Changed {
|
||||||
|
customAddr := string(customDNSAddressConverted)
|
||||||
|
req.CustomDNSAddress = &customAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(extraIFaceBlackListFlag).Changed {
|
||||||
|
req.ExtraIFaceBlacklist = setFlags.ExtraIFaceBlackList
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||||
|
req.DnsLabels = setFlags.DNSLabels
|
||||||
|
req.CleanDNSLabels = &[]bool{setFlags.DNSLabels != nil && len(setFlags.DNSLabels) == 0}[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(externalIPMapFlag).Changed {
|
||||||
|
req.NatExternalIPs = setFlags.NATExternalIPs
|
||||||
|
req.CleanNATExternalIPs = &[]bool{setFlags.NATExternalIPs != nil && len(setFlags.NATExternalIPs) == 0}[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||||
|
req.DnsRouteInterval = durationpb.New(setFlags.DNSRouteInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
req.DisableClientRoutes = &setFlags.DisableClientRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
req.DisableServerRoutes = &setFlags.DisableServerRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
req.DisableDns = &setFlags.DisableDNS
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
req.DisableFirewall = &setFlags.DisableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
req.BlockLanAccess = &setFlags.BlockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
|
req.BlockInbound = &setFlags.BlockInbound
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
|
req.LazyConnectionEnabled = &setFlags.LazyConnEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||||
|
return fmt.Errorf("update configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("Configuration updated successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
110
client/cmd/set_test.go
Normal file
110
client/cmd/set_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseBoolArg(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
hasError bool
|
||||||
|
}{
|
||||||
|
{"true", true, false},
|
||||||
|
{"True", true, false},
|
||||||
|
{"1", true, false},
|
||||||
|
{"yes", true, false},
|
||||||
|
{"on", true, false},
|
||||||
|
{"false", false, false},
|
||||||
|
{"False", false, false},
|
||||||
|
{"0", false, false},
|
||||||
|
{"no", false, false},
|
||||||
|
{"off", false, false},
|
||||||
|
{"invalid", false, true},
|
||||||
|
{"maybe", false, true},
|
||||||
|
{"", false, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.input, func(t *testing.T) {
|
||||||
|
result, err := parseBoolArg(test.input)
|
||||||
|
|
||||||
|
if test.hasError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error for input %q, but got none", test.input)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for input %q: %v", test.input, err)
|
||||||
|
}
|
||||||
|
if result != test.expected {
|
||||||
|
t.Errorf("For input %q, expected %v but got %v", test.input, test.expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetCommandStructure(t *testing.T) {
|
||||||
|
// Test that the set command has the expected subcommands
|
||||||
|
expectedSubcommands := []string{
|
||||||
|
"autoconnect",
|
||||||
|
"ssh-server",
|
||||||
|
"network-monitor",
|
||||||
|
"rosenpass",
|
||||||
|
"dns",
|
||||||
|
"dns-interval",
|
||||||
|
}
|
||||||
|
|
||||||
|
actualSubcommands := make([]string, 0, len(setCmd.Commands()))
|
||||||
|
for _, cmd := range setCmd.Commands() {
|
||||||
|
actualSubcommands = append(actualSubcommands, cmd.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(actualSubcommands) != len(expectedSubcommands) {
|
||||||
|
t.Errorf("Expected %d subcommands, got %d", len(expectedSubcommands), len(actualSubcommands))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectedSubcommands {
|
||||||
|
found := false
|
||||||
|
for _, actual := range actualSubcommands {
|
||||||
|
if actual == expected {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected subcommand %q not found", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetCommandUsage(t *testing.T) {
|
||||||
|
if setCmd.Use != "set" {
|
||||||
|
t.Errorf("Expected command use to be 'set', got %q", setCmd.Use)
|
||||||
|
}
|
||||||
|
|
||||||
|
if setCmd.Short != "Set NetBird client configuration" {
|
||||||
|
t.Errorf("Expected short description to be 'Set NetBird client configuration', got %q", setCmd.Short)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubcommandArgRequirements(t *testing.T) {
|
||||||
|
// Test that all subcommands except dns-interval require exactly 1 argument
|
||||||
|
subcommands := []*cobra.Command{
|
||||||
|
setAutoconnectCmd,
|
||||||
|
setSSHServerCmd,
|
||||||
|
setNetworkMonitorCmd,
|
||||||
|
setRosenpassCmd,
|
||||||
|
setDNSCmd,
|
||||||
|
setDNSIntervalCmd,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cmd := range subcommands {
|
||||||
|
if cmd.Args == nil {
|
||||||
|
t.Errorf("Command %q should have Args validation", cmd.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -44,7 +44,7 @@ func init() {
|
|||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -69,7 +69,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
status := resp.GetStatus()
|
||||||
|
|
||||||
|
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||||
|
status == string(internal.StatusSessionExpired) {
|
||||||
cmd.Printf("Daemon status: %s\n\n"+
|
cmd.Printf("Daemon status: %s\n\n"+
|
||||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
@@ -117,7 +120,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
@@ -127,12 +130,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "idle", "connecting", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
if len(ipsFilter) > 0 {
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ const (
|
|||||||
disableServerRoutesFlag = "disable-server-routes"
|
disableServerRoutesFlag = "disable-server-routes"
|
||||||
disableDNSFlag = "disable-dns"
|
disableDNSFlag = "disable-dns"
|
||||||
disableFirewallFlag = "disable-firewall"
|
disableFirewallFlag = "disable-firewall"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
|
blockInboundFlag = "block-inbound"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -13,19 +15,7 @@ var (
|
|||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
disableFirewall bool
|
disableFirewall bool
|
||||||
|
blockLANAccess bool
|
||||||
|
blockInbound bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Add system flags to upCmd
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
|
||||||
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
|
||||||
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
|
||||||
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
|
||||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
|||||||
Example: `
|
Example: `
|
||||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
Args: cobra.ExactArgs(3),
|
Args: cobra.ExactArgs(3),
|
||||||
RunE: tracePacket,
|
RunE: tracePacket,
|
||||||
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
for _, stage := range resp.Stages {
|
for _, stage := range resp.Stages {
|
||||||
if stage.ForwardingDetails != nil {
|
if stage.ForwardingDetails != nil {
|
||||||
|
|||||||
200
client/cmd/up.go
200
client/cmd/up.go
@@ -7,7 +7,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -15,7 +14,6 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -42,6 +40,7 @@ var (
|
|||||||
dnsLabels []string
|
dnsLabels []string
|
||||||
dnsLabelsValidated domain.List
|
dnsLabelsValidated domain.List
|
||||||
noBrowser bool
|
noBrowser bool
|
||||||
|
upFlags = &SharedFlags{}
|
||||||
|
|
||||||
upCmd = &cobra.Command{
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
@@ -51,27 +50,12 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
// Add shared flags to up command
|
||||||
|
AddSharedFlags(upCmd, upFlags)
|
||||||
|
|
||||||
|
// Add up-specific flags
|
||||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
AddUpOnlyFlags(upCmd, upFlags)
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
|
||||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
|
||||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
|
||||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
|
||||||
)
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
|
||||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
|
||||||
`Sets DNS labels`+
|
|
||||||
`You can specify a comma-separated list of up to 32 labels. `+
|
|
||||||
`An empty string "" clears the previous configuration. `+
|
|
||||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
|
||||||
`or --extra-dns-labels ""`,
|
|
||||||
)
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -119,79 +103,18 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ic := internal.ConfigInput{
|
// Handle DNS labels validation and assignment to SharedFlags
|
||||||
ManagementURL: managementURL,
|
if cmd.Flag(dnsLabelsFlag).Changed {
|
||||||
AdminURL: adminURL,
|
var err error
|
||||||
ConfigPath: configPath,
|
dnsLabelsValidated, err = validateDnsLabels(upFlags.DNSLabels)
|
||||||
NATExternalIPs: natExternalIPs,
|
if err != nil {
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
|
||||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
|
||||||
DNSLabels: dnsLabelsValidated,
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
|
||||||
ic.RosenpassEnabled = &rosenpassEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(rosenpassPermissiveFlag).Changed {
|
|
||||||
ic.RosenpassPermissive = &rosenpassPermissive
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
|
||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ic.InterfaceName = &interfaceName
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(wireguardPortFlag).Changed {
|
ic, err := BuildConfigInput(cmd, upFlags, customDNSAddressConverted)
|
||||||
p := int(wireguardPort)
|
if err != nil {
|
||||||
ic.WireguardPort = &p
|
return fmt.Errorf("setup config: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(networkMonitorFlag).Changed {
|
|
||||||
ic.NetworkMonitor = &networkMonitor
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
|
||||||
ic.PreSharedKey = &preSharedKey
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
|
||||||
ic.DisableAutoConnect = &autoConnectDisabled
|
|
||||||
|
|
||||||
if autoConnectDisabled {
|
|
||||||
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !autoConnectDisabled {
|
|
||||||
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
|
||||||
ic.DNSRouteInterval = &dnsRouteInterval
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(disableClientRoutesFlag).Changed {
|
|
||||||
ic.DisableClientRoutes = &disableClientRoutes
|
|
||||||
}
|
|
||||||
if cmd.Flag(disableServerRoutesFlag).Changed {
|
|
||||||
ic.DisableServerRoutes = &disableServerRoutes
|
|
||||||
}
|
|
||||||
if cmd.Flag(disableDNSFlag).Changed {
|
|
||||||
ic.DisableDNS = &disableDNS
|
|
||||||
}
|
|
||||||
if cmd.Flag(disableFirewallFlag).Changed {
|
|
||||||
ic.DisableFirewall = &disableFirewall
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd.Flag(blockLANAccessFlag).Changed {
|
|
||||||
ic.BlockLANAccess = &blockLANAccess
|
|
||||||
}
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
@@ -199,7 +122,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(ic)
|
config, err := internal.UpdateOrCreateConfig(*ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
}
|
}
|
||||||
@@ -258,9 +181,55 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("get setup key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setup login request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginErr error
|
||||||
|
var loginResp *proto.LoginResponse
|
||||||
|
|
||||||
|
err = WithBackOff(func() error {
|
||||||
|
var backOffErr error
|
||||||
|
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||||
|
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||||
|
s.Code() == codes.PermissionDenied ||
|
||||||
|
s.Code() == codes.NotFound ||
|
||||||
|
s.Code() == codes.Unimplemented) {
|
||||||
|
loginErr = backOffErr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return backOffErr
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginErr != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", loginErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||||
|
return fmt.Errorf("call service up method: %v", err)
|
||||||
|
}
|
||||||
|
cmd.Println("Connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
@@ -268,7 +237,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
DnsLabels: dnsLabels,
|
DnsLabels: dnsLabels,
|
||||||
@@ -297,7 +266,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
loginRequest.InterfaceName = &interfaceName
|
loginRequest.InterfaceName = &interfaceName
|
||||||
}
|
}
|
||||||
@@ -332,45 +301,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.BlockLanAccess = &blockLANAccess
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
var loginErr error
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
|
loginRequest.BlockInbound = &blockInbound
|
||||||
var loginResp *proto.LoginResponse
|
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
var backOffErr error
|
|
||||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
|
||||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
|
||||||
s.Code() == codes.PermissionDenied ||
|
|
||||||
s.Code() == codes.NotFound ||
|
|
||||||
s.Code() == codes.Unimplemented) {
|
|
||||||
loginErr = backOffErr
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return backOffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginErr != nil {
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
return fmt.Errorf("login failed: %v", loginErr)
|
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
}
|
}
|
||||||
|
return &loginRequest, nil
|
||||||
if loginResp.NeedsSSOLogin {
|
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
|
||||||
return fmt.Errorf("call service up method: %v", err)
|
|
||||||
}
|
|
||||||
cmd.Println("Connected")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNATExternalIPs(list []string) error {
|
func validateNATExternalIPs(list []string) error {
|
||||||
|
|||||||
@@ -147,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -198,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
@@ -219,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -248,10 +248,6 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -278,10 +274,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
|||||||
@@ -116,6 +116,8 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
|
IsStateful() bool
|
||||||
|
|
||||||
AddRouteFiltering(
|
AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
|
|||||||
@@ -170,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -324,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
expectedExprs2 := []expr.Any{
|
expectedExprs2 := []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: add.AsSlice(),
|
Data: ip.AsSlice(),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
|||||||
@@ -573,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -1006,10 +1002,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ type ConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -19,6 +20,10 @@ const (
|
|||||||
DefaultICMPTimeout = 30 * time.Second
|
DefaultICMPTimeout = 30 * time.Second
|
||||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
ICMPCleanupInterval = 15 * time.Second
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
|
||||||
|
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||||
|
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||||
|
MaxICMPPayloadLength = 28
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
@@ -29,7 +34,7 @@ type ICMPConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i ICMPConnKey) String() string {
|
func (i ICMPConnKey) String() string {
|
||||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
@@ -50,6 +55,72 @@ type ICMPTracker struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
||||||
|
type ICMPInfo struct {
|
||||||
|
TypeCode layers.ICMPv4TypeCode
|
||||||
|
PayloadData [MaxICMPPayloadLength]byte
|
||||||
|
// actual length of valid data
|
||||||
|
PayloadLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||||
|
func (info ICMPInfo) String() string {
|
||||||
|
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||||
|
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||||
|
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info.TypeCode.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||||
|
func (info ICMPInfo) isErrorMessage() bool {
|
||||||
|
typ := info.TypeCode.Type()
|
||||||
|
return typ == 3 || // Destination Unreachable
|
||||||
|
typ == 5 || // Redirect
|
||||||
|
typ == 11 || // Time Exceeded
|
||||||
|
typ == 12 // Parameter Problem
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||||
|
func (info ICMPInfo) parseOriginalPacket() string {
|
||||||
|
if info.PayloadLen < MaxICMPPayloadLength {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle IPv6
|
||||||
|
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := info.PayloadData[9]
|
||||||
|
srcIP := net.IP(info.PayloadData[12:16])
|
||||||
|
dstIP := net.IP(info.PayloadData[16:20])
|
||||||
|
|
||||||
|
transportData := info.PayloadData[20:]
|
||||||
|
|
||||||
|
switch nftypes.Protocol(protocol) {
|
||||||
|
case nftypes.TCP:
|
||||||
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
|
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
case nftypes.UDP:
|
||||||
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
|
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
case nftypes.ICMP:
|
||||||
|
icmpType := transportData[0]
|
||||||
|
icmpCode := transportData[1]
|
||||||
|
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
@@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP connection
|
// TrackOutbound records an outbound ICMP connection
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
func (t *ICMPTracker) TrackOutbound(
|
||||||
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound records an inbound ICMP Echo Request
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
func (t *ICMPTracker) TrackInbound(
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
ruleId []byte,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
func (t *ICMPTracker) track(
|
||||||
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
direction nftypes.Direction,
|
||||||
|
ruleId []byte,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
if exists {
|
if exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
typ, code := typecode.Type(), typecode.Code()
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
icmpInfo := ICMPInfo{
|
||||||
|
TypeCode: typecode,
|
||||||
|
}
|
||||||
|
if len(payload) > 0 {
|
||||||
|
icmpInfo.PayloadLen = len(payload)
|
||||||
|
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
||||||
|
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
||||||
|
}
|
||||||
|
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
||||||
|
}
|
||||||
|
|
||||||
// non echo requests don't need tracking
|
// non echo requests don't need tracking
|
||||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
|||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|||||||
@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
|
|||||||
|
|
||||||
func (i epID) String() string {
|
func (i epID) String() string {
|
||||||
// src and remote is swapped
|
// src and remote is swapped
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type Forwarder struct {
|
|||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip net.IP
|
ip tcpip.Address
|
||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ones, _ := iface.Address().Network.Mask.Size()
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
PrefixLen: ones,
|
PrefixLen: iface.Address().Network.Bits(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: iface.Address().IP,
|
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
}
|
}
|
||||||
|
|
||||||
receiveWindow := defaultReceiveWindow
|
receiveWindow := defaultReceiveWindow
|
||||||
@@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
if f.netstack && f.ip.Equal(addr) {
|
||||||
return net.IPv4(127, 0, 0, 1)
|
return net.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
return addr.AsSlice()
|
return addr.AsSlice()
|
||||||
@@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
|
|
||||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
return value.([]byte), true
|
return value.([]byte), true
|
||||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||||
|
|||||||
@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
if errInToOut != nil {
|
if errInToOut != nil {
|
||||||
if !isClosedError(errInToOut) {
|
if !isClosedError(errInToOut) {
|
||||||
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
|
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if errOutToIn != nil {
|
if errOutToIn != nil {
|
||||||
if !isClosedError(errOutToIn) {
|
if !isClosedError(errOutToIn) {
|
||||||
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
|
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||||
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
|
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
||||||
}
|
}
|
||||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||||
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
|
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rxPackets, txPackets uint64
|
var rxPackets, txPackets uint64
|
||||||
|
|||||||
@@ -45,8 +45,12 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if !ip.Is4() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv4 := ip.AsSlice()
|
||||||
|
|
||||||
high := uint16(ipv4[0])
|
high := uint16(ipv4[0])
|
||||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
|
|
||||||
@@ -58,11 +62,9 @@ func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap,
|
|||||||
bit := low % 32
|
bit := low % 32
|
||||||
bitmap[high].bitmap[index] |= 1 << bit
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
|
||||||
ipStr := ipv4.String()
|
if _, exists := ipv4Set[ip]; !exists {
|
||||||
if _, exists := ipv4Set[ipStr]; !exists {
|
ipv4Set[ip] = struct{}{}
|
||||||
ipv4Set[ipStr] = struct{}{}
|
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
|||||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[string]struct{})
|
ipv4Set := make(map[netip.Addr]struct{})
|
||||||
var ipv4Addresses []string
|
var ipv4Addresses []netip.Addr
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
|
|||||||
@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match - addresses 32 apart",
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("fe80::1"),
|
IP: netip.MustParseAddr("fe80::1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("fe80::"),
|
|
||||||
Mask: net.CIDRMask(64, 128),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
|
|||||||
@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,8 +39,12 @@ const (
|
|||||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
||||||
|
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
||||||
|
|
||||||
|
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
||||||
|
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -71,7 +75,6 @@ type Manager struct {
|
|||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -148,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
}
|
}
|
||||||
|
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
@@ -269,7 +277,7 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
case !m.netstack && m.nativeFirewall != nil:
|
||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
@@ -326,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return m.stateful
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
@@ -606,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.stateful {
|
// for netflow we keep track even if the firewall is stateless
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -660,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -673,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -777,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// if running in netstack mode we need to pass this to the forwarder
|
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||||
if m.netstack && m.localForwarding {
|
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||||
return m.handleNetstackLocalTraffic(packetData)
|
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||||
|
return m.handleForwardedLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track inbound packets to get the correct direction and session id for flows
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
@@ -789,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||||
|
|
||||||
fwd := m.forwarder.Load()
|
fwd := m.forwarder.Load()
|
||||||
if fwd == nil {
|
if fwd == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
@@ -1088,11 +1099,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
|
||||||
m.wgNetwork = network
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
|
|||||||
@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply scenario-specific setup
|
// Apply scenario-specific setup
|
||||||
sc.setupFunc(manager)
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-populate connection table
|
// Pre-populate connection table
|
||||||
srcIPs := generateRandomIPs(count)
|
srcIPs := generateRandomIPs(count)
|
||||||
dstIPs := generateRandomIPs(count)
|
dstIPs := generateRandomIPs(count)
|
||||||
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP := generateRandomIPs(1)[0]
|
srcIP := generateRandomIPs(1)[0]
|
||||||
dstIP := generateRandomIPs(1)[0]
|
dstIP := generateRandomIPs(1)[0]
|
||||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "post_handshake",
|
state: "post_handshake",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
dst := fw.Network{Prefix: r.dest}
|
||||||
|
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,12 +19,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
localIP := net.ParseIP("100.10.0.100")
|
localIP := netip.MustParseAddr("100.10.0.100")
|
||||||
wgNet := &net.IPNet{
|
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
@@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = wgNet
|
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
dev := mocks.NewMockDevice(ctrl)
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
localIP, wgNet, err := net.ParseCIDR(network)
|
wgNet := netip.MustParsePrefix(network)
|
||||||
require.NoError(tb, err)
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: wgNet.Addr(),
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
@@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a.AsSlice()) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|||||||
17
client/iface/configurer/common.go
Normal file
17
client/iface/configurer/common.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||||
|
ipNets := make([]net.IPNet, len(prefixes))
|
||||||
|
for i, prefix := range prefixes {
|
||||||
|
ipNets[i] = net.IPNet{
|
||||||
|
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
|
||||||
|
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ipNets
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -12,6 +13,8 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var zeroKey wgtypes.Key
|
||||||
|
|
||||||
type KernelConfigurer struct {
|
type KernelConfigurer struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
}
|
}
|
||||||
@@ -43,7 +46,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -52,7 +55,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: allowedIps,
|
AllowedIPs: prefixesToIPNets(allowedIps),
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
@@ -89,10 +92,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
ipNet := net.IPNet{
|
||||||
if err != nil {
|
IP: allowedIP.Addr().AsSlice(),
|
||||||
return err
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -103,7 +106,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
AllowedIPs: []net.IPNet{ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
@@ -116,10 +119,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
ipNet := net.IPNet{
|
||||||
if err != nil {
|
IP: allowedIP.Addr().AsSlice(),
|
||||||
return fmt.Errorf("parse allowed IP: %w", err)
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -187,7 +190,11 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer wg.Close()
|
defer func() {
|
||||||
|
if err := wg.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close wgctrl client: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// validate if device with name exists
|
// validate if device with name exists
|
||||||
_, err = wg.Device(c.deviceName)
|
_, err = wg.Device(c.deviceName)
|
||||||
@@ -201,14 +208,71 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
func (c *KernelConfigurer) Close() {
|
func (c *KernelConfigurer) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
}
|
}
|
||||||
return WGStats{
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
fullStats := &Stats{
|
||||||
|
DeviceName: wgDevice.Name,
|
||||||
|
PublicKey: wgDevice.PublicKey.String(),
|
||||||
|
ListenPort: wgDevice.ListenPort,
|
||||||
|
FWMark: wgDevice.FirewallMark,
|
||||||
|
Peers: []Peer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range wgDevice.Peers {
|
||||||
|
peer := Peer{
|
||||||
|
PublicKey: p.PublicKey.String(),
|
||||||
|
AllowedIPs: p.AllowedIPs,
|
||||||
|
TxBytes: p.TransmitBytes,
|
||||||
|
RxBytes: p.ReceiveBytes,
|
||||||
|
LastHandshake: p.LastHandshakeTime,
|
||||||
|
PresharedKey: p.PresharedKey != zeroKey,
|
||||||
|
}
|
||||||
|
if p.Endpoint != nil {
|
||||||
|
peer.Endpoint = *p.Endpoint
|
||||||
|
}
|
||||||
|
fullStats.Peers = append(fullStats.Peers, peer)
|
||||||
|
}
|
||||||
|
return fullStats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
|
stats := make(map[string]WGStats)
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range wgDevice.Peers {
|
||||||
|
stats[peer.PublicKey.String()] = WGStats{
|
||||||
LastHandshake: peer.LastHandshakeTime,
|
LastHandshake: peer.LastHandshakeTime,
|
||||||
TxBytes: peer.TransmitBytes,
|
TxBytes: peer.TransmitBytes,
|
||||||
RxBytes: peer.ReceiveBytes,
|
RxBytes: peer.ReceiveBytes,
|
||||||
}, nil
|
}
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -17,6 +19,20 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
privateKey = "private_key"
|
||||||
|
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||||
|
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||||
|
ipcKeyTxBytes = "tx_bytes"
|
||||||
|
ipcKeyRxBytes = "rx_bytes"
|
||||||
|
allowedIP = "allowed_ip"
|
||||||
|
endpoint = "endpoint"
|
||||||
|
fwmark = "fwmark"
|
||||||
|
listenPort = "listen_port"
|
||||||
|
publicKey = "public_key"
|
||||||
|
presharedKey = "preshared_key"
|
||||||
|
)
|
||||||
|
|
||||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||||
|
|
||||||
type WGUSPConfigurer struct {
|
type WGUSPConfigurer struct {
|
||||||
@@ -52,7 +68,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -61,7 +77,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: allowedIps,
|
AllowedIPs: prefixesToIPNets(allowedIps),
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
@@ -91,10 +107,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
ipNet := net.IPNet{
|
||||||
if err != nil {
|
IP: allowedIP.Addr().AsSlice(),
|
||||||
return err
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -105,7 +121,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
AllowedIPs: []net.IPNet{ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
@@ -115,7 +131,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
ipc, err := c.device.IpcGet()
|
ipc, err := c.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -138,6 +154,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
|
|
||||||
foundPeer := false
|
foundPeer := false
|
||||||
removedAllowedIP := false
|
removedAllowedIP := false
|
||||||
|
ip := allowedIP.String()
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
@@ -160,8 +178,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
|
|
||||||
// Append the line to the output string
|
// Append the line to the output string
|
||||||
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
||||||
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
|
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
_, ipNet, err := net.ParseCIDR(allowedIPStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -178,6 +196,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||||
|
ipcStr, err := c.device.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseStatus(c.deviceName, ipcStr)
|
||||||
|
}
|
||||||
|
|
||||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||||
func (t *WGUSPConfigurer) startUAPI() {
|
func (t *WGUSPConfigurer) startUAPI() {
|
||||||
var err error
|
var err error
|
||||||
@@ -217,91 +244,75 @@ func (t *WGUSPConfigurer) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
ipc, err := t.device.IpcGet()
|
ipc, err := t.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
return nil, fmt.Errorf("ipc get: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
return parseTransfers(ipc)
|
||||||
"last_handshake_time_sec",
|
|
||||||
"last_handshake_time_nsec",
|
|
||||||
"tx_bytes",
|
|
||||||
"rx_bytes",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
|
||||||
}
|
|
||||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
|
||||||
}
|
|
||||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return WGStats{
|
|
||||||
LastHandshake: time.Unix(sec, nsec),
|
|
||||||
TxBytes: txBytes,
|
|
||||||
RxBytes: rxBytes,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
stats := make(map[string]WGStats)
|
||||||
if err != nil {
|
var (
|
||||||
return nil, fmt.Errorf("parse key: %w", err)
|
currentKey string
|
||||||
}
|
currentStats WGStats
|
||||||
|
hasPeer bool
|
||||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
)
|
||||||
|
lines := strings.Split(ipc, "\n")
|
||||||
lines := strings.Split(ipcInput, "\n")
|
|
||||||
|
|
||||||
configFound := map[string]string{}
|
|
||||||
foundPeer := false
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
// If we're within the details of the found peer and encounter another public key,
|
// If we're within the details of the found peer and encounter another public key,
|
||||||
// this means we're starting another peer's details. So, stop.
|
// this means we're starting another peer's details. So, stop.
|
||||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
if strings.HasPrefix(line, "public_key=") {
|
||||||
break
|
peerID := strings.TrimPrefix(line, "public_key=")
|
||||||
|
h, err := hex.DecodeString(peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||||
|
}
|
||||||
|
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||||
|
currentStats = WGStats{} // Reset stats for the new peer
|
||||||
|
hasPeer = true
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identify the peer with the specific public key
|
if !hasPeer {
|
||||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
continue
|
||||||
foundPeer = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range searchConfigKeys {
|
key := strings.SplitN(line, "=", 2)
|
||||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
if len(key) != 2 {
|
||||||
v := strings.SplitN(line, "=", 2)
|
continue
|
||||||
configFound[v[0]] = v[1]
|
|
||||||
}
|
}
|
||||||
|
switch key[0] {
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
hs, err := toLastHandshake(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
currentStats.LastHandshake = hs
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
rxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.RxBytes = rxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
TxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.TxBytes = TxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: use multierr
|
return stats, nil
|
||||||
for _, key := range searchConfigKeys {
|
|
||||||
if _, ok := configFound[key]; !ok {
|
|
||||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundPeer {
|
|
||||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return configFound, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||||
@@ -355,9 +366,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||||
|
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||||
|
}
|
||||||
|
return time.Unix(sec, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toBytes(s string) (int64, error) {
|
||||||
|
return strconv.ParseInt(s, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
||||||
|
// Decode hex string to bytes
|
||||||
|
keyBytes, err := hex.DecodeString(hexKey)
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
||||||
|
if len(keyBytes) != 32 {
|
||||||
|
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to wgtypes.Key
|
||||||
|
var key wgtypes.Key
|
||||||
|
copy(key[:], keyBytes)
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||||
|
stats := &Stats{DeviceName: deviceName}
|
||||||
|
var currentPeer *Peer
|
||||||
|
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(line, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := parts[0]
|
||||||
|
val := parts[1]
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
case privateKey:
|
||||||
|
key, err := hexToWireguardKey(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse private key: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats.PublicKey = key.PublicKey().String()
|
||||||
|
case publicKey:
|
||||||
|
// Save previous peer
|
||||||
|
if currentPeer != nil {
|
||||||
|
stats.Peers = append(stats.Peers, *currentPeer)
|
||||||
|
}
|
||||||
|
key, err := hexToWireguardKey(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse public key: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer = &Peer{
|
||||||
|
PublicKey: key.String(),
|
||||||
|
}
|
||||||
|
case listenPort:
|
||||||
|
if port, err := strconv.Atoi(val); err == nil {
|
||||||
|
stats.ListenPort = port
|
||||||
|
}
|
||||||
|
case fwmark:
|
||||||
|
if fwmark, err := strconv.Atoi(val); err == nil {
|
||||||
|
stats.FWMark = fwmark
|
||||||
|
}
|
||||||
|
case endpoint:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse endpoint: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse endpoint port: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.Endpoint = net.UDPAddr{
|
||||||
|
IP: net.ParseIP(host),
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
case allowedIP:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(val)
|
||||||
|
if err == nil {
|
||||||
|
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
||||||
|
}
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rxBytes, err := toBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.TxBytes = rxBytes
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rxBytes, err := toBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.RxBytes = rxBytes
|
||||||
|
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := toLastHandshake(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.LastHandshake = ts
|
||||||
|
case presharedKey:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if val != "" {
|
||||||
|
currentPeer.PresharedKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentPeer != nil {
|
||||||
|
stats.Peers = append(stats.Peers, *currentPeer)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ package configurer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@@ -34,58 +32,35 @@ errno=0
|
|||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func Test_findPeerInfo(t *testing.T) {
|
func Test_parseTransfers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
peerKey string
|
peerKey string
|
||||||
searchKeys []string
|
want WGStats
|
||||||
want map[string]string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||||
searchKeys: []string{"tx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 0,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 0,
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 38333,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 2224,
|
||||||
"rx_bytes": "2224",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "lastpeer",
|
name: "lastpeer",
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 1212111,
|
||||||
"tx_bytes": "1212111",
|
RxBytes: 1929999999,
|
||||||
"rx_bytes": "1929999999",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "peer not found",
|
|
||||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
|
||||||
searchKeys: nil,
|
|
||||||
want: nil,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "key not found",
|
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
|
||||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
|
||||||
want: map[string]string{
|
|
||||||
"tx_bytes": "1212111",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
|||||||
key, err := wgtypes.NewKey(res)
|
key, err := wgtypes.NewKey(res)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
stats, err := parseTransfers(ipcFixture)
|
||||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
if err != nil {
|
||||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, ok := stats[key.String()]
|
||||||
|
if !ok {
|
||||||
|
require.True(t, ok)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, tt.want, stat)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
24
client/iface/configurer/wgshow.go
Normal file
24
client/iface/configurer/wgshow.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
PublicKey string
|
||||||
|
Endpoint net.UDPAddr
|
||||||
|
AllowedIPs []net.IPNet
|
||||||
|
TxBytes int64
|
||||||
|
RxBytes int64
|
||||||
|
LastHandshake time.Time
|
||||||
|
PresharedKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
DeviceName string
|
||||||
|
PublicKey string
|
||||||
|
ListenPort int
|
||||||
|
FWMark int
|
||||||
|
Peers []Peer
|
||||||
|
}
|
||||||
@@ -24,6 +24,7 @@ type WGTunDevice struct {
|
|||||||
mtu int
|
mtu int
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
|
disableDNS bool
|
||||||
|
|
||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
@@ -32,7 +33,7 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
|||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
|
disableDNS: disableDNS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
routesString := routesToString(routes)
|
routesString := routesToString(routes)
|
||||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||||
|
|
||||||
|
// Skip DNS configuration when DisableDNS is enabled
|
||||||
|
if t.disableDNS {
|
||||||
|
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||||
|
dns = ""
|
||||||
|
searchDomainsToString = ""
|
||||||
|
}
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -24,9 +23,6 @@ type PacketFilter interface {
|
|||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
SetNetwork(*net.IPNet)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
|
|||||||
@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
log.Info("create nbnetstack tun interface")
|
log.Info("create nbnetstack tun interface")
|
||||||
|
|
||||||
// TODO: get from service listener runtime IP
|
// TODO: get from service listener runtime IP
|
||||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("last ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("netstack using address: %s", t.address.IP)
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -11,10 +12,11 @@ import (
|
|||||||
|
|
||||||
type WGConfigurer interface {
|
type WGConfigurer interface {
|
||||||
ConfigureInterface(privateKey string, port int) error
|
ConfigureInterface(privateKey string, port int) error
|
||||||
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
Close()
|
Close()
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
|
FullStats() (*configurer.Stats, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ip := address.IP.String()
|
ip := address.IP.String()
|
||||||
mask := "0x" + address.Network.Mask.String()
|
|
||||||
|
// Convert prefix length to hex netmask
|
||||||
|
prefixLen := address.Network.Bits()
|
||||||
|
if !address.IP.Is4() {
|
||||||
|
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||||
|
}
|
||||||
|
|
||||||
|
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||||
|
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||||
|
|
||||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
|
|||||||
MobileArgs *device.MobileIFaceArguments
|
MobileArgs *device.MobileIFaceArguments
|
||||||
TransportNet transport.Net
|
TransportNet transport.Net
|
||||||
FilterFn bind.FilterFn
|
FilterFn bind.FilterFn
|
||||||
|
DisableDNS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// WGIface represents an interface instance
|
// WGIface represents an interface instance
|
||||||
@@ -111,14 +112,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
// Endpoint is optional
|
// Endpoint is optional.
|
||||||
|
// If allowedIps is given it will be added to the existing ones.
|
||||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
netIPNets := prefixesToIPNets(allowedIps)
|
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
|
||||||
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
|
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||||
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
@@ -131,7 +132,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
||||||
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@@ -140,7 +141,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
||||||
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@@ -185,7 +186,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.filter = filter
|
w.filter = filter
|
||||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
|
||||||
|
|
||||||
w.tun.FilteredDevice().SetFilter(filter)
|
w.tun.FilteredDevice().SetFilter(filter)
|
||||||
return nil
|
return nil
|
||||||
@@ -212,9 +212,13 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
|||||||
return w.tun.Device()
|
return w.tun.Device()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||||
|
return w.configurer.FullStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGIface) waitUntilRemoved() error {
|
func (w *WGIface) waitUntilRemoved() error {
|
||||||
@@ -251,14 +255,3 @@ func (w *WGIface) GetNet() *netstack.Net {
|
|||||||
|
|
||||||
return w.tun.GetNet()
|
return w.tun.GetNet()
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
|
||||||
ipNets := make([]net.IPNet, len(prefixes))
|
|
||||||
for i, prefix := range prefixes {
|
|
||||||
ipNets[i] = net.IPNet{
|
|
||||||
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
|
|
||||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ipNets
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
|
||||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork indicates an expected call of SetNetwork.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -15,8 +13,8 @@ import (
|
|||||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
address net.IP
|
address netip.Addr
|
||||||
dnsAddress net.IP
|
dnsAddress netip.Addr
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
|
||||||
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
address: address,
|
address: address,
|
||||||
dnsAddress: dnsAddress,
|
dnsAddress: dnsAddress,
|
||||||
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
addr, ok := netip.AddrFromSlice(t.address)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
[]netip.Addr{addr.Unmap()},
|
[]netip.Addr{t.address},
|
||||||
[]netip.Addr{dnsAddr.Unmap()},
|
[]netip.Addr{t.dnsAddress},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -2,28 +2,27 @@ package wgaddr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type Address struct {
|
type Address struct {
|
||||||
IP net.IP
|
IP netip.Addr
|
||||||
Network *net.IPNet
|
Network netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (Address, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
ip, network, err := net.ParseCIDR(address)
|
prefix, err := netip.ParsePrefix(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Address{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return Address{
|
return Address{
|
||||||
IP: ip,
|
IP: prefix.Addr().Unmap(),
|
||||||
Network: network,
|
Network: prefix.Masked(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr Address) String() string {
|
func (addr Address) String() string {
|
||||||
maskSize, _ := addr.Network.Mask.Size()
|
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,8 @@
|
|||||||
|
|
||||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
|
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@@ -49,6 +51,10 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
!include "MUI2.nsh"
|
||||||
|
!include LogicLib.nsh
|
||||||
|
!include "nsDialogs.nsh"
|
||||||
|
|
||||||
!define MUI_ICON "${ICON}"
|
!define MUI_ICON "${ICON}"
|
||||||
!define MUI_UNICON "${ICON}"
|
!define MUI_UNICON "${ICON}"
|
||||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
@@ -58,9 +64,6 @@ ShowInstDetails Show
|
|||||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
!include "MUI2.nsh"
|
|
||||||
!include LogicLib.nsh
|
|
||||||
|
|
||||||
!define MUI_ABORTWARNING
|
!define MUI_ABORTWARNING
|
||||||
!define MUI_UNABORTWARNING
|
!define MUI_UNABORTWARNING
|
||||||
|
|
||||||
@@ -70,13 +73,16 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
; Custom page for autostart checkbox
|
|
||||||
Page custom AutostartPage AutostartPageLeave
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
|
|
||||||
|
!insertmacro MUI_UNPAGE_WELCOME
|
||||||
|
|
||||||
|
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_CONFIRM
|
!insertmacro MUI_UNPAGE_CONFIRM
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_INSTFILES
|
!insertmacro MUI_UNPAGE_INSTFILES
|
||||||
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
|
|||||||
Var AutostartCheckbox
|
Var AutostartCheckbox
|
||||||
Var AutostartEnabled
|
Var AutostartEnabled
|
||||||
|
|
||||||
|
; Variables for uninstall data deletion option
|
||||||
|
Var DeleteDataCheckbox
|
||||||
|
Var DeleteDataEnabled
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
; Function to create the autostart options page
|
; Function to create the autostart options page
|
||||||
@@ -104,8 +114,8 @@ Function AutostartPage
|
|||||||
|
|
||||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
Pop $AutostartCheckbox
|
Pop $AutostartCheckbox
|
||||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
${NSD_Check} $AutostartCheckbox
|
||||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
StrCpy $AutostartEnabled "1"
|
||||||
|
|
||||||
nsDialogs::Show
|
nsDialogs::Show
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
@@ -115,6 +125,30 @@ Function AutostartPageLeave
|
|||||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to create the uninstall data deletion page
|
||||||
|
Function un.DeleteDataPage
|
||||||
|
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
||||||
|
|
||||||
|
nsDialogs::Create 1018
|
||||||
|
Pop $0
|
||||||
|
|
||||||
|
${If} $0 == error
|
||||||
|
Abort
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
||||||
|
Pop $DeleteDataCheckbox
|
||||||
|
${NSD_Uncheck} $DeleteDataCheckbox
|
||||||
|
StrCpy $DeleteDataEnabled "0"
|
||||||
|
|
||||||
|
nsDialogs::Show
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to handle leaving the data deletion page
|
||||||
|
Function un.DeleteDataPageLeave
|
||||||
|
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@@ -225,31 +259,58 @@ SectionEnd
|
|||||||
Section Uninstall
|
Section Uninstall
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
|
|
||||||
|
DetailPrint "Stopping Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||||
|
DetailPrint "Uninstalling Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
# kill ui client
|
DetailPrint "Terminating Netbird UI process..."
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
|
DetailPrint "Removing autostart registry entry if exists..."
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
|
; Handle data deletion based on checkbox
|
||||||
|
DetailPrint "Checking if user requested data deletion..."
|
||||||
|
${If} $DeleteDataEnabled == "1"
|
||||||
|
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
||||||
|
ClearErrors
|
||||||
|
RMDir /r "${NETBIRD_DATA_DIR}"
|
||||||
|
IfErrors 0 +2 ; If no errors, jump over the message
|
||||||
|
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
||||||
|
DetailPrint "Netbird data directory removal complete."
|
||||||
|
${Else}
|
||||||
|
DetailPrint "User did not opt to delete Netbird data."
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
|
DetailPrint "Waiting for service handle to be released..."
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|
||||||
|
DetailPrint "Deleting application files..."
|
||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
Delete "$INSTDIR\opengl32.dll"
|
||||||
|
DetailPrint "Removing application directory..."
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Removing shortcuts..."
|
||||||
SetShellVarContext all
|
SetShellVarContext all
|
||||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||||
|
|
||||||
|
DetailPrint "Removing registry keys..."
|
||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
|
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||||
|
|
||||||
|
DetailPrint "Removing application directory from PATH..."
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::DeleteValue "path" "$INSTDIR"
|
EnVar::DeleteValue "path" "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Uninstallation finished."
|
||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
|
|
||||||
|
if d.firewall == nil {
|
||||||
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
total := 0
|
total := 0
|
||||||
@@ -69,20 +74,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if d.firewall == nil {
|
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
d.applyPeerACLs(networkMap)
|
||||||
|
|
||||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
|
||||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
|
||||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
|
||||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
|
||||||
log.Errorf("failed to set legacy management flag: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
@@ -291,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
if d.firewall.IsStateful() {
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
// return traffic for outbound connections if firewall is stateless
|
||||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@@ -403,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||||
if drop {
|
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||||
|
|
||||||
|
if hasPortRestrictions {
|
||||||
|
// Don't squash rules with port restrictions
|
||||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = &protoMatch{
|
protocols[r.Protocol] = &protoMatch{
|
||||||
ips: map[string]int{},
|
ips: map[string]int{},
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Errorf("create firewall: %v", err)
|
defer func() {
|
||||||
return
|
err = fw.Close(nil)
|
||||||
}
|
require.NoError(t, err)
|
||||||
defer func(fw manager.Manager) {
|
}()
|
||||||
_ = fw.Close(nil)
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if fw.IsStateful() {
|
||||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
return
|
} else {
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -94,12 +91,13 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
expectedRules := 2
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if fw.IsStateful() {
|
||||||
t.Errorf("firewall rules not applied")
|
expectedRules = 1 // only the inbound rule
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if previousCount != 1 {
|
|
||||||
t.Errorf("old rule was not removed")
|
expectedPreviousCount := 0
|
||||||
|
if !fw.IsStateful() {
|
||||||
|
expectedPreviousCount = 1
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedPreviousCount, previousCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
t.Run("handle default rules", func(t *testing.T) {
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
networkMap.FirewallRulesIsEmpty = true
|
||||||
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
|
acl.ApplyFiltering(networkMap, false)
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
if len(acl.peerRulesPairs) != 1 {
|
|
||||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
expectedRules := 1
|
||||||
return
|
if fw.IsStateful() {
|
||||||
|
expectedRules = 1 // only inbound allow-all rule
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultManagerStateless(t *testing.T) {
|
||||||
|
// stateless currently only in userspace, so we have to disable kernel
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
Port: "53",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = fw.Close(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// In stateless mode, we should have both inbound and outbound rules
|
||||||
|
assert.False(t, fw.IsStateful())
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
if len(rules) != 2 {
|
assert.Equal(t, 2, len(rules))
|
||||||
t.Errorf("rules should contain 2, got: %v", rules)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r := rules[0]
|
r := rules[0]
|
||||||
switch {
|
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||||
case r.PeerIP != "0.0.0.0":
|
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||||
return
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||||
case r.Direction != mgmProto.RuleDirection_IN:
|
|
||||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r = rules[1]
|
r = rules[1]
|
||||||
switch {
|
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||||
case r.PeerIP != "0.0.0.0":
|
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||||
return
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
|
||||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||||
@@ -291,8 +326,435 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rules []*mgmProto.FirewallRule
|
||||||
|
expectedCount int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should not squash rules with port ranges",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with specific ports",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with legacy port field",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with legacy port field should not be squashed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with DROP action",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with DROP action should not be squashed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should squash rules without port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 1,
|
||||||
|
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed rules should not squash protocol with port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "TCP should not be squashed because one rule has port restrictions",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
// TCP rules with port restrictions - should NOT be squashed
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
// UDP rules without port restrictions - SHOULD be squashed
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||||
|
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||||
|
{AllowedIps: []string{"10.93.0.1"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.2"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.3"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.4"}},
|
||||||
|
},
|
||||||
|
FirewallRules: tt.rules,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &DefaultManager{}
|
||||||
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||||
|
|
||||||
|
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||||
|
if tt.expectedCount == 1 {
|
||||||
|
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||||
|
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||||
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
portInfo *mgmProto.PortInfo
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil PortInfo should be empty",
|
||||||
|
portInfo: nil,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero port should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with valid port should not be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with nil range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero start range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 0,
|
||||||
|
End: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero end range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 80,
|
||||||
|
End: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with valid range should not be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := portInfoEmpty(tt.portInfo)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,33 +798,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Errorf("create firewall: %v", err)
|
defer func() {
|
||||||
return
|
err = fw.Close(nil)
|
||||||
}
|
require.NoError(t, err)
|
||||||
defer func(fw manager.Manager) {
|
}()
|
||||||
_ = fw.Close(nil)
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap, false)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 3 {
|
expectedRules := 3
|
||||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
if fw.IsStateful() {
|
||||||
return
|
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,8 +101,13 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
|
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
}
|
}
|
||||||
|
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|
||||||
|
|||||||
@@ -7,15 +7,36 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPromptLogin(t *testing.T) {
|
func TestPromptLogin(t *testing.T) {
|
||||||
|
const (
|
||||||
|
promptLogin = "prompt=login"
|
||||||
|
maxAge0 = "max_age=0"
|
||||||
|
)
|
||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
prompt bool
|
loginFlag mgm.LoginFlag
|
||||||
|
disablePromptLogin bool
|
||||||
|
expect string
|
||||||
}{
|
}{
|
||||||
{"PromptLogin", true},
|
{
|
||||||
{"NoPromptLogin", false},
|
name: "Prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
expect: promptLogin,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max age 0 login",
|
||||||
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
|
expect: maxAge0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disable prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
disablePromptLogin: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
DisablePromptLogin: !tc.prompt,
|
LoginFlag: tc.loginFlag,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
pattern := "prompt=login"
|
|
||||||
if tc.prompt {
|
if !tc.disablePromptLogin {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||||
} else {
|
} else {
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||||
|
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,12 +68,14 @@ type ConfigInput struct {
|
|||||||
DisableServerRoutes *bool
|
DisableServerRoutes *bool
|
||||||
DisableDNS *bool
|
DisableDNS *bool
|
||||||
DisableFirewall *bool
|
DisableFirewall *bool
|
||||||
|
|
||||||
BlockLANAccess *bool
|
BlockLANAccess *bool
|
||||||
|
BlockInbound *bool
|
||||||
|
|
||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
DNSLabels domain.List
|
DNSLabels domain.List
|
||||||
|
|
||||||
|
LazyConnectionEnabled *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@@ -96,8 +98,8 @@ type Config struct {
|
|||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
DisableDNS bool
|
DisableDNS bool
|
||||||
DisableFirewall bool
|
DisableFirewall bool
|
||||||
|
|
||||||
BlockLANAccess bool
|
BlockLANAccess bool
|
||||||
|
BlockInbound bool
|
||||||
|
|
||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
@@ -138,6 +140,8 @@ type Config struct {
|
|||||||
ClientCertKeyPath string
|
ClientCertKeyPath string
|
||||||
|
|
||||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||||
|
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
@@ -219,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config := &Config{
|
config := &Config{
|
||||||
// defaults to false only for new (post 0.26) configurations
|
// defaults to false only for new (post 0.26) configurations
|
||||||
ServerSSHAllowed: util.False(),
|
ServerSSHAllowed: util.False(),
|
||||||
|
// default to disabling server routes on Android for security
|
||||||
|
DisableServerRoutes: runtime.GOOS == "android",
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := config.apply(input); err != nil {
|
if _, err := config.apply(input); err != nil {
|
||||||
@@ -313,10 +319,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
*input.WireguardPort, config.WgPort)
|
*input.WireguardPort, config.WgPort)
|
||||||
config.WgPort = *input.WireguardPort
|
config.WgPort = *input.WireguardPort
|
||||||
updated = true
|
updated = true
|
||||||
} else if config.WgPort == 0 {
|
|
||||||
config.WgPort = iface.DefaultWgPort
|
|
||||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
|
||||||
updated = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||||
@@ -412,9 +414,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||||
updated = true
|
updated = true
|
||||||
} else if config.ServerSSHAllowed == nil {
|
} else if config.ServerSSHAllowed == nil {
|
||||||
|
if runtime.GOOS == "android" {
|
||||||
|
// default to disabled SSH on Android for security
|
||||||
|
log.Infof("setting SSH server to false by default on Android")
|
||||||
|
config.ServerSSHAllowed = util.False()
|
||||||
|
} else {
|
||||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||||
config.ServerSSHAllowed = util.True()
|
config.ServerSSHAllowed = util.True()
|
||||||
|
}
|
||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,6 +487,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
|
||||||
|
if *input.BlockInbound {
|
||||||
|
log.Infof("blocking inbound connections")
|
||||||
|
} else {
|
||||||
|
log.Infof("allowing inbound connections")
|
||||||
|
}
|
||||||
|
config.BlockInbound = *input.BlockInbound
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||||
if *input.DisableNotifications {
|
if *input.DisableNotifications {
|
||||||
log.Infof("disabling notifications")
|
log.Infof("disabling notifications")
|
||||||
@@ -524,6 +542,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||||
|
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||||
|
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
312
client/internal/conn_mgr.go
Normal file
312
client/internal/conn_mgr.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||||
|
//
|
||||||
|
// The connection manager is responsible for:
|
||||||
|
// - Managing lazy connections via the lazyConnManager
|
||||||
|
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||||
|
// - Handling connection establishment based on peer signaling
|
||||||
|
//
|
||||||
|
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||||
|
type ConnMgr struct {
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
iface lazyconn.WGIface
|
||||||
|
dispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
enabledLocally bool
|
||||||
|
|
||||||
|
lazyConnMgr *manager.Manager
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
lazyCtx context.Context
|
||||||
|
lazyCtxCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||||
|
e := &ConnMgr{
|
||||||
|
peerStore: peerStore,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
iface: iface,
|
||||||
|
dispatcher: dispatcher,
|
||||||
|
}
|
||||||
|
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||||
|
e.enabledLocally = true
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||||
|
func (e *ConnMgr) Start(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
log.Errorf("lazy connection manager is already started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.enabledLocally {
|
||||||
|
log.Infof("lazy connection manager is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
|
||||||
|
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||||
|
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||||
|
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||||
|
// do not disable lazy connection manager if it was enabled by env var
|
||||||
|
if e.enabledLocally {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if enabled {
|
||||||
|
// if the lazy connection manager is already started, do not start it again
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
return e.addPeersToLazyConnManager()
|
||||||
|
} else {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||||
|
e.closeManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRouteHAMap updates the route HA mappings in the lazy connection manager
|
||||||
|
func (e *ConnMgr) UpdateRouteHAMap(haMap route.HAMap) {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
log.Debugf("lazy connection manager is not started, skipping UpdateRouteHAMap")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyConnMgr.UpdateRouteHAMap(haMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
|
||||||
|
func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
|
||||||
|
|
||||||
|
for peerID := range peerIDs {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
|
||||||
|
for _, peerID := range added {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
|
||||||
|
if err := peerConn.Open(ctx); err != nil {
|
||||||
|
peerConn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
|
||||||
|
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !lazyconn.IsSupported(conn.AgentVersionString()) {
|
||||||
|
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerKey,
|
||||||
|
AllowedIPs: conn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: conn.ConnID(),
|
||||||
|
Log: conn.Log,
|
||||||
|
}
|
||||||
|
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
|
||||||
|
if err != nil {
|
||||||
|
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if excluded {
|
||||||
|
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("peer added to lazy conn manager")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||||
|
conn, ok := e.peerStore.Remove(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyConnMgr.RemovePeer(peerKey)
|
||||||
|
conn.Log.Infof("removed peer from lazy conn manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||||
|
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
|
||||||
|
conn.Log.Infof("activated peer from inactive state")
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) Close() {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyCtxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
|
||||||
|
cfg := manager.Config{
|
||||||
|
InactivityThreshold: inactivityThresholdEnv(),
|
||||||
|
}
|
||||||
|
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
|
||||||
|
|
||||||
|
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
|
e.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.wg.Done()
|
||||||
|
e.lazyConnMgr.Start(e.lazyCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) addPeersToLazyConnManager() error {
|
||||||
|
peers := e.peerStore.PeersPubKey()
|
||||||
|
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
|
||||||
|
for _, peerID := range peers {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyCtxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
|
||||||
|
for _, peerID := range e.peerStore.PeersPubKey() {
|
||||||
|
e.peerStore.PeerConnOpen(ctx, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||||
|
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func inactivityThresholdEnv() *time.Duration {
|
||||||
|
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||||
|
if envValue == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedMinutes, err := strconv.Atoi(envValue)
|
||||||
|
if err != nil || parsedMinutes <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d := time.Duration(parsedMinutes) * time.Minute
|
||||||
|
return &d
|
||||||
|
}
|
||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@@ -436,11 +435,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
DisableServerRoutes: config.DisableServerRoutes,
|
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||||
DisableDNS: config.DisableDNS,
|
DisableDNS: config.DisableDNS,
|
||||||
DisableFirewall: config.DisableFirewall,
|
DisableFirewall: config.DisableFirewall,
|
||||||
|
|
||||||
BlockLANAccess: config.BlockLANAccess,
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
|
BlockInbound: config.BlockInbound,
|
||||||
|
|
||||||
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@@ -481,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
return signalClient, nil
|
return signalClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
@@ -498,6 +499,9 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.DisableServerRoutes,
|
config.DisableServerRoutes,
|
||||||
config.DisableDNS,
|
config.DisableDNS,
|
||||||
config.DisableFirewall,
|
config.DisableFirewall,
|
||||||
|
config.BlockLANAccess,
|
||||||
|
config.BlockInbound,
|
||||||
|
config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -521,17 +525,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
|||||||
|
|
||||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||||
func freePort(initPort int) (int, error) {
|
func freePort(initPort int) (int, error) {
|
||||||
addr := net.UDPAddr{}
|
addr := net.UDPAddr{Port: initPort}
|
||||||
if initPort == 0 {
|
|
||||||
initPort = iface.DefaultWgPort
|
|
||||||
}
|
|
||||||
|
|
||||||
addr.Port = initPort
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", &addr)
|
conn, err := net.ListenUDP("udp", &addr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
closeConnWithLog(conn)
|
closeConnWithLog(conn)
|
||||||
return initPort, nil
|
return returnPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the port is already in use, ask the system for a free port
|
// if the port is already in use, ask the system for a free port
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
|
|||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "not provided, fallback to default",
|
name: "when port is 0 use random port",
|
||||||
port: 0,
|
port: 0,
|
||||||
want: 51820,
|
want: 0,
|
||||||
shouldMatch: true,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "provided and available",
|
name: "provided and available",
|
||||||
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
|
|||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("freePort error = %v", err)
|
t.Errorf("freePort error = %v", err)
|
||||||
}
|
}
|
||||||
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
|
|||||||
_ = c1.Close()
|
_ = c1.Close()
|
||||||
}(c1)
|
}(c1)
|
||||||
|
|
||||||
|
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
|
||||||
|
tests[1].port++
|
||||||
|
tests[1].want++
|
||||||
|
}
|
||||||
|
|
||||||
|
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -269,11 +270,21 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
|
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.logFile != "console" {
|
if err := g.addWgShow(); err != nil {
|
||||||
|
log.Errorf("Failed to add wg show output: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.logFile != "console" && g.logFile != "" {
|
||||||
if err := g.addLogfile(); err != nil {
|
if err := g.addLogfile(); err != nil {
|
||||||
return fmt.Errorf("add log file: %w", err)
|
log.Errorf("Failed to add log file to debug bundle: %v", err)
|
||||||
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
log.Errorf("Failed to add systemd logs as fallback: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
log.Errorf("Failed to add systemd logs: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,17 +376,34 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
|
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
|
||||||
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
|
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
|
||||||
if g.internalConfig.ServerSSHAllowed != nil {
|
if g.internalConfig.ServerSSHAllowed != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||||
}
|
}
|
||||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
|
|
||||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
|
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
|
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||||
|
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
|
||||||
|
|
||||||
|
if g.internalConfig.DisableNotifications != nil {
|
||||||
|
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
|
||||||
|
}
|
||||||
|
|
||||||
|
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
|
||||||
|
|
||||||
|
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
|
||||||
|
|
||||||
|
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
|
||||||
|
|
||||||
|
if g.internalConfig.ClientCertPath != "" {
|
||||||
|
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
|
||||||
|
}
|
||||||
|
if g.internalConfig.ClientCertKeyPath != "" {
|
||||||
|
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
@@ -533,6 +561,33 @@ func (g *BundleGenerator) addLogfile() error {
|
|||||||
return fmt.Errorf("add client log file to zip: %w", err)
|
return fmt.Errorf("add client log file to zip: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add latest rotated log file
|
||||||
|
pattern := filepath.Join(logDir, "client-*.log.gz")
|
||||||
|
files, err := filepath.Glob(pattern)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to glob rotated logs: %v", err)
|
||||||
|
} else if len(files) > 0 {
|
||||||
|
// pick the file with the latest ModTime
|
||||||
|
sort.Slice(files, func(i, j int) bool {
|
||||||
|
fi, err := os.Stat(files[i])
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fj, err := os.Stat(files[j])
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return fi.ModTime().Before(fj.ModTime())
|
||||||
|
})
|
||||||
|
latest := files[len(files)-1]
|
||||||
|
name := filepath.Base(latest)
|
||||||
|
if err := g.addSingleLogFileGz(latest, name); err != nil {
|
||||||
|
log.Warnf("failed to add rotated log %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
stdErrLogPath := filepath.Join(logDir, errorLogFile)
|
stdErrLogPath := filepath.Join(logDir, errorLogFile)
|
||||||
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
||||||
if runtime.GOOS == "darwin" {
|
if runtime.GOOS == "darwin" {
|
||||||
@@ -563,16 +618,13 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var logReader io.Reader
|
var logReader io.Reader = logFile
|
||||||
if g.anonymize {
|
if g.anonymize {
|
||||||
var writer *io.PipeWriter
|
var writer *io.PipeWriter
|
||||||
logReader, writer = io.Pipe()
|
logReader, writer = io.Pipe()
|
||||||
|
|
||||||
go anonymizeLog(logFile, writer, g.anonymizer)
|
go anonymizeLog(logFile, writer, g.anonymizer)
|
||||||
} else {
|
|
||||||
logReader = logFile
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addFileToZip(logReader, targetName); err != nil {
|
if err := g.addFileToZip(logReader, targetName); err != nil {
|
||||||
return fmt.Errorf("add %s to zip: %w", targetName, err)
|
return fmt.Errorf("add %s to zip: %w", targetName, err)
|
||||||
}
|
}
|
||||||
@@ -580,6 +632,44 @@ func (g *BundleGenerator) addSingleLogfile(logPath, targetName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addSingleLogFileGz adds a single gzipped log file to the archive
|
||||||
|
func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
|
||||||
|
f, err := os.Open(logPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open gz log file %s: %w", targetName, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
gzr, err := gzip.NewReader(f)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
defer gzr.Close()
|
||||||
|
|
||||||
|
var logReader io.Reader = gzr
|
||||||
|
if g.anonymize {
|
||||||
|
var pw *io.PipeWriter
|
||||||
|
logReader, pw = io.Pipe()
|
||||||
|
go anonymizeLog(gzr, pw, g.anonymizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gw := gzip.NewWriter(&buf)
|
||||||
|
if _, err := io.Copy(gw, logReader); err != nil {
|
||||||
|
return fmt.Errorf("re-gzip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gw.Close(); err != nil {
|
||||||
|
return fmt.Errorf("close gzip writer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(&buf, targetName); err != nil {
|
||||||
|
return fmt.Errorf("add anonymized gz: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
|
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
|
||||||
header := &zip.FileHeader{
|
header := &zip.FileHeader{
|
||||||
Name: filename,
|
Name: filename,
|
||||||
|
|||||||
@@ -4,17 +4,104 @@ package debug
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxLogEntries = 100000
|
||||||
|
maxLogAge = 7 * 24 * time.Hour // Last 7 days
|
||||||
|
)
|
||||||
|
|
||||||
|
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
|
||||||
|
func (g *BundleGenerator) trySystemdLogFallback() error {
|
||||||
|
log.Debug("Attempting to collect systemd journal logs")
|
||||||
|
|
||||||
|
serviceName := getServiceName()
|
||||||
|
journalLogs, err := getSystemdLogs(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(journalLogs, "No recent log entries found") {
|
||||||
|
log.Debug("No recent log entries found in systemd journal")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.anonymize {
|
||||||
|
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
|
||||||
|
}
|
||||||
|
|
||||||
|
logReader := strings.NewReader(journalLogs)
|
||||||
|
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
|
||||||
|
if err := g.addFileToZip(logReader, fileName); err != nil {
|
||||||
|
return fmt.Errorf("add systemd logs to bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getServiceName gets the service name from environment or defaults to netbird
|
||||||
|
func getServiceName() string {
|
||||||
|
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
|
||||||
|
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
|
||||||
|
return unitName
|
||||||
|
}
|
||||||
|
|
||||||
|
return "netbird"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
|
||||||
|
func getSystemdLogs(serviceName string) (string, error) {
|
||||||
|
args := []string{
|
||||||
|
"-u", fmt.Sprintf("%s.service", serviceName),
|
||||||
|
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
|
||||||
|
"--lines", fmt.Sprintf("%d", maxLogEntries),
|
||||||
|
"--no-pager",
|
||||||
|
"--output", "short-iso",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "journalctl", args...)
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "executable file not found") {
|
||||||
|
return "", fmt.Errorf("journalctl command not found: %w", err)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
logs := stdout.String()
|
||||||
|
if strings.TrimSpace(logs) == "" {
|
||||||
|
return "No recent log entries found in systemd journal", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
|
||||||
|
serviceName, maxLogEntries, maxLogAge.String())
|
||||||
|
|
||||||
|
return header + logs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addFirewallRules collects and adds firewall rules to the archive
|
// addFirewallRules collects and adds firewall rules to the archive
|
||||||
func (g *BundleGenerator) addFirewallRules() error {
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
log.Info("Collecting firewall rules")
|
log.Info("Collecting firewall rules")
|
||||||
@@ -481,7 +568,7 @@ func formatExpr(exp expr.Any) string {
|
|||||||
case *expr.Fib:
|
case *expr.Fib:
|
||||||
return formatFib(e)
|
return formatFib(e)
|
||||||
case *expr.Target:
|
case *expr.Target:
|
||||||
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets
|
return fmt.Sprintf("jump %s", e.Name)
|
||||||
case *expr.Immediate:
|
case *expr.Immediate:
|
||||||
if e.Register == 1 {
|
if e.Register == 1 {
|
||||||
return formatImmediateData(e.Data)
|
return formatImmediateData(e.Data)
|
||||||
|
|||||||
@@ -6,3 +6,9 @@ package debug
|
|||||||
func (g *BundleGenerator) addFirewallRules() error {
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) trySystemdLogFallback() error {
|
||||||
|
// Systemd is only available on Linux
|
||||||
|
// TODO: Add BSD support
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
66
client/internal/debug/wgshow.go
Normal file
66
client/internal/debug/wgshow.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WGIface interface {
|
||||||
|
FullStats() (*configurer.Stats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addWgShow() error {
|
||||||
|
result, err := g.statusRecorder.PeersStatus()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
output := g.toWGShowFormat(result)
|
||||||
|
reader := bytes.NewReader([]byte(output))
|
||||||
|
|
||||||
|
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add wg show to zip: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
|
||||||
|
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
|
||||||
|
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
|
||||||
|
if s.FWMark != 0 {
|
||||||
|
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range s.Peers {
|
||||||
|
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
|
||||||
|
if peer.Endpoint.IP != nil {
|
||||||
|
if g.anonymize {
|
||||||
|
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
|
||||||
|
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
|
||||||
|
} else {
|
||||||
|
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(peer.AllowedIPs) > 0 {
|
||||||
|
var ipStrings []string
|
||||||
|
for _, ipnet := range peer.AllowedIPs {
|
||||||
|
ipStrings = append(ipStrings, ipnet.String())
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||||
|
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||||
|
if peer.PresharedKey {
|
||||||
|
sb.WriteString(" preshared key: (hidden)\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -12,13 +12,14 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||||
ip := net.ParseIP(aRecord.RData)
|
ip, err := netip.ParseAddr(aRecord.RData)
|
||||||
if ip == nil || ip.To4() == nil {
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ipNet.Contains(ip) {
|
if !prefix.Contains(ip) {
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
networkIP := network.Masked().Addr()
|
||||||
maskOnes, _ := ipNet.Mask.Size()
|
|
||||||
|
if !networkIP.Is4() {
|
||||||
|
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||||
|
}
|
||||||
|
|
||||||
// round up to nearest byte
|
// round up to nearest byte
|
||||||
octetsToUse := (maskOnes + 7) / 8
|
octetsToUse := (network.Bits() + 7) / 8
|
||||||
|
|
||||||
octets := strings.Split(networkIP.String(), ".")
|
octets := strings.Split(networkIP.String(), ".")
|
||||||
if octetsToUse > len(octets) {
|
if octetsToUse > len(octets) {
|
||||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||||
}
|
}
|
||||||
|
|
||||||
reverseOctets := make([]string, octetsToUse)
|
reverseOctets := make([]string, octetsToUse)
|
||||||
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
|
||||||
records = append(records, ptrRecord)
|
records = append(records, ptrRecord)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||||
zoneName, err := generateReverseZoneName(ipNet)
|
zoneName, err := generateReverseZoneName(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err)
|
log.Warn(err)
|
||||||
return
|
return
|
||||||
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
records := collectPTRRecords(config, ipNet)
|
records := collectPTRRecords(config, network)
|
||||||
|
|
||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -10,8 +11,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PriorityDNSRoute = 100
|
PriorityLocal = 100
|
||||||
PriorityMatchDomain = 50
|
PriorityDNSRoute = 75
|
||||||
|
PriorityUpstream = 50
|
||||||
PriorityDefault = 1
|
PriorityDefault = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -148,47 +150,27 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
qname := strings.ToLower(r.Question[0].Name)
|
qname := strings.ToLower(r.Question[0].Name)
|
||||||
log.Tracef("handling DNS request for domain=%s", qname)
|
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
handlers := slices.Clone(c.handlers)
|
handlers := slices.Clone(c.handlers)
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
|
|
||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
log.Tracef("current handlers (%d):", len(handlers))
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||||
for _, h := range handlers {
|
for _, h := range handlers {
|
||||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||||
}
|
}
|
||||||
|
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
var matched bool
|
matched := c.isHandlerMatch(qname, entry)
|
||||||
switch {
|
|
||||||
case entry.Pattern == ".":
|
|
||||||
matched = true
|
|
||||||
case entry.IsWildcard:
|
|
||||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
|
||||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
|
||||||
default:
|
|
||||||
// For non-wildcard patterns:
|
|
||||||
// If handler wants subdomain matching, allow suffix match
|
|
||||||
// Otherwise require exact match
|
|
||||||
if entry.MatchSubdomains {
|
|
||||||
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
|
||||||
} else {
|
|
||||||
matched = strings.EqualFold(qname, entry.Pattern)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !matched {
|
if matched {
|
||||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
chainWriter := &ResponseWriterChain{
|
||||||
@@ -199,11 +181,12 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// If handler wants to continue, try next handler
|
// If handler wants to continue, try next handler
|
||||||
if chainWriter.shouldContinue {
|
if chainWriter.shouldContinue {
|
||||||
log.Tracef("handler requested continue to next handler")
|
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// No handler matched or all handlers passed
|
// No handler matched or all handlers passed
|
||||||
log.Tracef("no handler found for domain=%s", qname)
|
log.Tracef("no handler found for domain=%s", qname)
|
||||||
@@ -213,3 +196,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
log.Errorf("failed to write DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
|
switch {
|
||||||
|
case entry.Pattern == ".":
|
||||||
|
return true
|
||||||
|
case entry.IsWildcard:
|
||||||
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||||
|
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||||
|
default:
|
||||||
|
// For non-wildcard patterns:
|
||||||
|
// If handler wants subdomain matching, allow suffix match
|
||||||
|
// Otherwise require exact match
|
||||||
|
if entry.MatchSubdomains {
|
||||||
|
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
|
} else {
|
||||||
|
return strings.EqualFold(qname, entry.Pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
{pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
},
|
},
|
||||||
queryDomain: "test.example.com.",
|
queryDomain: "test.example.com.",
|
||||||
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
{pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
|
||||||
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
},
|
},
|
||||||
queryDomain: "sub.test.example.com.",
|
queryDomain: "sub.test.example.com.",
|
||||||
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain)
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
|
||||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
|
||||||
|
|
||||||
// Create test request
|
// Create test request
|
||||||
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
expectedCalls: map[int]bool{
|
expectedCalls: map[int]bool{
|
||||||
nbdns.PriorityDNSRoute: false,
|
nbdns.PriorityDNSRoute: false,
|
||||||
nbdns.PriorityMatchDomain: true,
|
nbdns.PriorityUpstream: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
{"remove", "example.com.", nbdns.PriorityUpstream},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
expectedCalls: map[int]bool{
|
expectedCalls: map[int]bool{
|
||||||
nbdns.PriorityDNSRoute: true,
|
nbdns.PriorityDNSRoute: true,
|
||||||
nbdns.PriorityMatchDomain: false,
|
nbdns.PriorityUpstream: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -378,15 +378,15 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
{"add", "example.com.", nbdns.PriorityUpstream},
|
||||||
{"add", "example.com.", nbdns.PriorityDefault},
|
{"add", "example.com.", nbdns.PriorityDefault},
|
||||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
{"remove", "example.com.", nbdns.PriorityUpstream},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
expectedCalls: map[int]bool{
|
expectedCalls: map[int]bool{
|
||||||
nbdns.PriorityDNSRoute: false,
|
nbdns.PriorityDNSRoute: false,
|
||||||
nbdns.PriorityMatchDomain: false,
|
nbdns.PriorityUpstream: false,
|
||||||
nbdns.PriorityDefault: true,
|
nbdns.PriorityDefault: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
// Test 1: Initial state
|
// Test 1: Initial state
|
||||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
defaultHandler.Calls = nil
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
}{
|
}{
|
||||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||||
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
{"example.com.", nbdns.PriorityUpstream, false, false},
|
||||||
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||||
},
|
},
|
||||||
query: "example.com.",
|
query: "example.com.",
|
||||||
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
},
|
},
|
||||||
query: "test.sub.example.com.",
|
query: "test.sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
{"add", "other.example.com.", nbdns.PriorityUpstream, true},
|
||||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
{"add", "sub.example.com.", nbdns.PriorityUpstream, false},
|
||||||
},
|
},
|
||||||
query: "sub.example.com.",
|
query: "sub.example.com.",
|
||||||
expectedMatch: "sub.example.com.",
|
expectedMatch: "sub.example.com.",
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -41,6 +44,20 @@ const (
|
|||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
|
|
||||||
|
// Network interface DNS registration settings
|
||||||
|
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
||||||
|
registrationEnabledKey = "RegistrationEnabled"
|
||||||
|
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
|
||||||
|
|
||||||
|
// NetBIOS/WINS settings
|
||||||
|
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
|
||||||
|
netbiosOptionsKey = "NetbiosOptions"
|
||||||
|
|
||||||
|
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
|
||||||
|
netbiosFromDHCP = 0
|
||||||
|
netbiosEnabled = 1
|
||||||
|
netbiosDisabled = 2
|
||||||
|
|
||||||
// RP_FORCE: Reapply all policies even if no policy change was detected
|
// RP_FORCE: Reapply all policies even if no policy change was detected
|
||||||
rpForce = 0x1
|
rpForce = 0x1
|
||||||
)
|
)
|
||||||
@@ -67,16 +84,85 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
|||||||
log.Infof("detected GPO DNS policy configuration, using policy store")
|
log.Infof("detected GPO DNS policy configuration, using policy store")
|
||||||
}
|
}
|
||||||
|
|
||||||
return ®istryConfigurator{
|
configurator := ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
gpo: useGPO,
|
gpo: useGPO,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
if err := configurator.configureInterface(); err != nil {
|
||||||
|
log.Errorf("failed to configure interface settings: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return configurator, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) supportCustomPort() bool {
|
func (r *registryConfigurator) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) configureInterface() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.disableDNSRegistrationForInterface(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.disableWINSForInterface(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
|
||||||
|
regKey, err := r.getInterfaceRegistryKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get interface registry key: %w", err)
|
||||||
|
}
|
||||||
|
defer closer(regKey)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if merr == nil || len(merr.Errors) == 0 {
|
||||||
|
log.Infof("disabled DNS registration for interface %s", r.guid)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) disableWINSForInterface() error {
|
||||||
|
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
|
||||||
|
|
||||||
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer closer(regKey)
|
||||||
|
|
||||||
|
// NetbiosOptions: 2 = disabled
|
||||||
|
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
|
||||||
|
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||||
@@ -119,9 +205,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.flushDNSCache(); err != nil {
|
go r.flushDNSCache()
|
||||||
log.Errorf("failed to flush DNS cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -191,7 +275,25 @@ func (r *registryConfigurator) string() string {
|
|||||||
return "registry"
|
return "registry"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) flushDNSCache() error {
|
func (r *registryConfigurator) registerDNS() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// nolint:misspell
|
||||||
|
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to register DNS: %v, output: %s", err, out)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("registered DNS names")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) flushDNSCache() {
|
||||||
|
r.registerDNS()
|
||||||
|
|
||||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||||
defer func() {
|
defer func() {
|
||||||
if rec := recover(); rec != nil {
|
if rec := recover(); rec != nil {
|
||||||
@@ -202,13 +304,14 @@ func (r *registryConfigurator) flushDNSCache() error {
|
|||||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||||
if ret == 0 {
|
if ret == 0 {
|
||||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
log.Errorf("DnsFlushResolverCache failed: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
log.Errorf("DnsFlushResolverCache failed")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("flushed DNS cache")
|
log.Info("flushed DNS cache")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
@@ -263,9 +366,7 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.flushDNSCache(); err != nil {
|
go r.flushDNSCache()
|
||||||
log.Errorf("failed to flush DNS cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,16 +12,19 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
records map[dns.Question][]dns.RR
|
records map[dns.Question][]dns.RR
|
||||||
|
domains map[domain.Domain]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question][]dns.RR),
|
records: make(map[dns.Question][]dns.RR),
|
||||||
|
domains: make(map[domain.Domain]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,8 +67,12 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
replyMessage.Rcode = dns.RcodeSuccess
|
replyMessage.Rcode = dns.RcodeSuccess
|
||||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||||
} else {
|
} else {
|
||||||
// TODO: return success if we have a different record type for the same name, relevant for search domains
|
// Check if we have any records for this domain name with different types
|
||||||
replyMessage.Rcode = dns.RcodeNameError
|
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||||
|
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
||||||
|
} else {
|
||||||
|
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(replyMessage); err != nil {
|
if err := w.WriteMsg(replyMessage); err != nil {
|
||||||
@@ -73,6 +80,15 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
||||||
|
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
_, exists := d.domains[domainName]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
@@ -111,6 +127,7 @@ func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
|||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
maps.Clear(d.records)
|
maps.Clear(d.records)
|
||||||
|
maps.Clear(d.domains)
|
||||||
|
|
||||||
for _, rec := range update {
|
for _, rec := range update {
|
||||||
if err := d.registerRecord(rec); err != nil {
|
if err := d.registerRecord(rec); err != nil {
|
||||||
@@ -144,6 +161,7 @@ func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.records[q] = append(d.records[q], rr)
|
d.records[q] = append(d.records[q], rr)
|
||||||
|
d.domains[domain.Domain(q.Name)] = struct{}{}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -470,3 +470,115 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_NoErrorWithDifferentRecordType verifies that querying for a record type
|
||||||
|
// that doesn't exist but where other record types exist for the same domain returns NOERROR
|
||||||
|
// with 0 records instead of NXDOMAIN
|
||||||
|
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
recordA := nbdns.SimpleRecord{
|
||||||
|
Name: "example.netbird.cloud.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.100",
|
||||||
|
}
|
||||||
|
|
||||||
|
recordCNAME := nbdns.SimpleRecord{
|
||||||
|
Name: "alias.netbird.cloud.",
|
||||||
|
Type: int(dns.TypeCNAME),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "target.example.com.",
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
queryName string
|
||||||
|
queryType uint16
|
||||||
|
expectedRcode int
|
||||||
|
shouldHaveData bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Query A record that exists",
|
||||||
|
queryName: "example.netbird.cloud.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
shouldHaveData: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query AAAA for domain with only A record",
|
||||||
|
queryName: "example.netbird.cloud.",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
shouldHaveData: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query other record with different case and non-fqdn",
|
||||||
|
queryName: "EXAMPLE.netbird.cloud",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
shouldHaveData: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query TXT for domain with only A record",
|
||||||
|
queryName: "example.netbird.cloud.",
|
||||||
|
queryType: dns.TypeTXT,
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
shouldHaveData: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query A for domain with only CNAME record",
|
||||||
|
queryName: "alias.netbird.cloud.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
shouldHaveData: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query AAAA for domain with only CNAME record",
|
||||||
|
queryName: "alias.netbird.cloud.",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
shouldHaveData: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query for completely non-existent domain",
|
||||||
|
queryName: "nonexistent.netbird.cloud.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedRcode: dns.RcodeNameError,
|
||||||
|
shouldHaveData: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
|
||||||
|
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode,
|
||||||
|
"Response code should be %d (%s)",
|
||||||
|
tc.expectedRcode, dns.RcodeToString[tc.expectedRcode])
|
||||||
|
|
||||||
|
if tc.shouldHaveData {
|
||||||
|
assert.Greater(t, len(responseMSG.Answer), 0, "Response should contain answers")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, len(responseMSG.Answer), "Response should contain no answers")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||||
@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityLocal,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
groupedNS := groupNSGroupsByDomain(nameServerGroups)
|
||||||
|
|
||||||
for _, domainGroup := range groupedNS {
|
for _, domainGroup := range groupedNS {
|
||||||
basePriority := PriorityMatchDomain
|
basePriority := PriorityUpstream
|
||||||
if domainGroup.domain == nbdns.RootZone {
|
if domainGroup.domain == nbdns.RootZone {
|
||||||
basePriority = PriorityDefault
|
basePriority = PriorityDefault
|
||||||
}
|
}
|
||||||
@@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
|
||||||
priority := basePriority - i
|
priority := basePriority - i
|
||||||
|
|
||||||
// Check if we're about to overlap with the next priority tier
|
// Check if we're about to overlap with the next priority tier.
|
||||||
if basePriority == PriorityMatchDomain && priority <= PriorityDefault {
|
// This boundary check ensures that the priority of upstream handlers does not conflict
|
||||||
|
// with the default priority tier. By decrementing the priority for each handler, we avoid
|
||||||
|
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
|
||||||
|
// handlers to maintain the integrity of the priority system.
|
||||||
|
if basePriority == PriorityUpstream && priority <= PriorityDefault {
|
||||||
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
|
||||||
domainGroup.domain, PriorityMatchDomain-PriorityDefault)
|
domainGroup.domain, PriorityUpstream-PriorityDefault)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() wgaddr.Address {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: ip,
|
IP: netip.MustParseAddr("100.66.100.1"),
|
||||||
Network: network,
|
Network: netip.MustParsePrefix("100.66.100.0/24"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
dummyHandler.ID(): handlerWrapper{
|
dummyHandler.ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityLocal,
|
||||||
},
|
},
|
||||||
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
@@ -187,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
@@ -211,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"local-resolver": handlerWrapper{
|
"local-resolver": handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityLocal,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
@@ -306,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
@@ -322,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
@@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("parse CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
t.Errorf("set packet filter: %v", err)
|
t.Errorf("set packet filter: %v", err)
|
||||||
@@ -503,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
"id1": handlerWrapper{
|
"id1": handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: &local.Resolver{},
|
handler: &local.Resolver{},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||||
@@ -986,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
||||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1067,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
"upstream-group2": {
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1101,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
"upstream-group2": {
|
"upstream-group2": {
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
"upstream-other": {
|
"upstream-other": {
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1136,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@@ -1154,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@@ -1172,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group3",
|
Id: "upstream-group3",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain + 1,
|
priority: PriorityUpstream + 1,
|
||||||
},
|
},
|
||||||
// Keep existing groups with their original priorities
|
// Keep existing groups with their original priorities
|
||||||
{
|
{
|
||||||
@@ -1180,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@@ -1207,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
// Add group3 with lowest priority
|
// Add group3 with lowest priority
|
||||||
{
|
{
|
||||||
@@ -1222,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group3",
|
Id: "upstream-group3",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 2,
|
priority: PriorityUpstream - 2,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@@ -1343,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@@ -1368,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group1",
|
Id: "upstream-group1",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-group2",
|
Id: "upstream-group2",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain - 1,
|
priority: PriorityUpstream - 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "other.com",
|
domain: "other.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-other",
|
Id: "upstream-other",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
domain: "new.com",
|
domain: "new.com",
|
||||||
handler: &mockHandler{
|
handler: &mockHandler{
|
||||||
Id: "upstream-new",
|
Id: "upstream-new",
|
||||||
},
|
},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityUpstream,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedHandlers: map[string]string{
|
expectedHandlers: map[string]string{
|
||||||
@@ -1799,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
|||||||
|
|
||||||
// Register domains from different handlers with same domain
|
// Register domains from different handlers with same domain
|
||||||
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
|
||||||
|
|
||||||
// Verify refcount is 2
|
// Verify refcount is 2
|
||||||
zoneKey := toZone("shared.example.com")
|
zoneKey := toZone("shared.example.com")
|
||||||
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||||
|
|
||||||
// Deregister one handler
|
// Deregister one handler
|
||||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
|
||||||
|
|
||||||
// Verify refcount is 1
|
// Verify refcount is 1
|
||||||
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||||
@@ -1933,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
|
||||||
|
|
||||||
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||||
|
|
||||||
@@ -1953,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
localResolver: local.NewResolver(),
|
||||||
|
service: &mockService{},
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "local.example.com",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "test.local.example.com",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.100",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"local.example.com"}, // Same domain as local records
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that local handler has higher priority than upstream for same domain
|
||||||
|
var localPriority, upstreamPriority int
|
||||||
|
localFound, upstreamFound := false, false
|
||||||
|
|
||||||
|
for _, update := range localMuxUpdates {
|
||||||
|
if update.domain == "local.example.com" {
|
||||||
|
localPriority = update.priority
|
||||||
|
localFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, update := range upstreamMuxUpdates {
|
||||||
|
if update.domain == "local.example.com" {
|
||||||
|
upstreamPriority = update.priority
|
||||||
|
upstreamFound = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, localFound, "Local handler should be found")
|
||||||
|
assert.True(t, upstreamFound, "Upstream handler should be found")
|
||||||
|
assert.Greater(t, localPriority, upstreamPriority,
|
||||||
|
"Local handler priority (%d) should be higher than upstream priority (%d)",
|
||||||
|
localPriority, upstreamPriority)
|
||||||
|
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
|
||||||
|
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||||
|
// Test that priority constants are ordered correctly
|
||||||
|
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||||
|
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||||
|
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||||
|
|
||||||
|
// Test that local resolver uses the correct priority
|
||||||
|
server := &DefaultServer{
|
||||||
|
localResolver: local.NewResolver(),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "local.example.com",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "test.local.example.com",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.1.100",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, localMuxUpdates, 1)
|
||||||
|
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||||
|
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,11 +24,15 @@ type ServiceViaMemory struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||||
|
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("get last ip from network: %v", err)
|
||||||
|
}
|
||||||
s := &ServiceViaMemory{
|
s := &ServiceViaMemory{
|
||||||
wgInterface: wgIface,
|
wgInterface: wgIface,
|
||||||
dnsMux: dns.NewServeMux(),
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
|
runtimeIP: lastIP.String(),
|
||||||
runtimePort: defaultPort,
|
runtimePort: defaultPort,
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
@@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
if s.wgInterface.Address().IP.Is6() {
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
|
|
||||||
if lastIP != tt.ip {
|
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -30,9 +30,12 @@ const (
|
|||||||
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
|
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
|
||||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||||
|
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||||
systemdDbusResolvConfModeForeign = "foreign"
|
systemdDbusResolvConfModeForeign = "foreign"
|
||||||
|
|
||||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||||
|
|
||||||
|
dnsSecDisabled = "no"
|
||||||
)
|
)
|
||||||
|
|
||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
@@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
Family: unix.AF_INET,
|
Family: unix.AF_INET,
|
||||||
Address: ipAs4[:],
|
Address: ipAs4[:],
|
||||||
}
|
}
|
||||||
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
|
||||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
|
}
|
||||||
|
|
||||||
|
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
|
||||||
|
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
|
||||||
|
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
requestID := GenerateRequestID()
|
||||||
|
logger := log.WithField("request_id", requestID)
|
||||||
var err error
|
var err error
|
||||||
defer func() {
|
defer func() {
|
||||||
u.checkUpstreamFails(err)
|
u.checkUpstreamFails(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
log.Tracef("%s has been stopped", u)
|
logger.Tracef("%s has been stopped", u)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.Response {
|
if rm == nil || !rm.Response {
|
||||||
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||||
|
|
||||||
if err = w.WriteMsg(rm); err != nil {
|
if err = w.WriteMsg(rm); err != nil {
|
||||||
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
// count the fails only if they happen sequentially
|
// count the fails only if they happen sequentially
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
u.failsCount.Add(1)
|
||||||
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(r, dns.RcodeServerFailure)
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
if err := w.WriteMsg(m); err != nil {
|
if err := w.WriteMsg(m); err != nil {
|
||||||
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
|
|
||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateRequestID() string {
|
||||||
|
bytes := make([]byte, 4)
|
||||||
|
_, err := rand.Read(bytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to generate request ID: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -23,8 +24,8 @@ type upstreamResolver struct {
|
|||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
_ string,
|
_ string,
|
||||||
_ net.IP,
|
_ netip.Addr,
|
||||||
_ *net.IPNet,
|
_ netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
@@ -83,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
|
return &dns.Client{
|
||||||
|
Timeout: dialTimeout,
|
||||||
|
Net: "udp",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -19,8 +19,8 @@ type upstreamResolver struct {
|
|||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
_ string,
|
_ string,
|
||||||
_ net.IP,
|
_ netip.Addr,
|
||||||
_ *net.IPNet,
|
_ netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
@@ -36,3 +36,10 @@ func newUpstreamResolver(
|
|||||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
|
return &dns.Client{
|
||||||
|
Timeout: dialTimeout,
|
||||||
|
Net: "udp",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -18,16 +19,16 @@ import (
|
|||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
lIP net.IP
|
lIP netip.Addr
|
||||||
lNet *net.IPNet
|
lNet netip.Prefix
|
||||||
interfaceName string
|
interfaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
interfaceName string,
|
interfaceName string,
|
||||||
ip net.IP,
|
ip netip.Addr,
|
||||||
net *net.IPNet,
|
net netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
@@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
}
|
}
|
||||||
client.DialTimeout = timeout
|
client.DialTimeout = timeout
|
||||||
|
|
||||||
upstreamIP := net.ParseIP(upstreamHost)
|
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||||
|
}
|
||||||
|
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
log.Debugf("using private client to query upstream: %s", upstream)
|
||||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
// This method is needed for iOS
|
// This method is needed for iOS
|
||||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
index, err := getInterfaceIndex(interfaceName)
|
index, err := getInterfaceIndex(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||||
@@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
|
|||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
LocalAddr: &net.UDPAddr{
|
LocalAddr: &net.UDPAddr{
|
||||||
IP: ip,
|
IP: ip.AsSlice(),
|
||||||
Port: 0, // Let the OS pick a free port
|
Port: 0, // Let the OS pick a free port
|
||||||
},
|
},
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -18,5 +17,4 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -13,6 +12,5 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,14 +18,20 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
const upstreamTimeout = 15 * time.Second
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type firewaller interface {
|
||||||
|
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
|
||||||
|
}
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress string
|
||||||
ttl uint32
|
ttl uint32
|
||||||
@@ -33,75 +39,94 @@ type DNSForwarder struct {
|
|||||||
|
|
||||||
dnsServer *dns.Server
|
dnsServer *dns.Server
|
||||||
mux *dns.ServeMux
|
mux *dns.ServeMux
|
||||||
|
tcpServer *dns.Server
|
||||||
|
tcpMux *dns.ServeMux
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
fwdEntries []*ForwarderEntry
|
fwdEntries []*ForwarderEntry
|
||||||
firewall firewall.Manager
|
firewall firewaller
|
||||||
|
resolver resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
resolver: net.DefaultResolver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
// UDP server
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
f.mux = mux
|
||||||
|
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||||
|
f.dnsServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Addr: f.listenAddress,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
f.dnsServer = dnsServer
|
|
||||||
f.mux = mux
|
// TCP server
|
||||||
|
tcpMux := dns.NewServeMux()
|
||||||
|
f.tcpMux = tcpMux
|
||||||
|
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||||
|
f.tcpServer = &dns.Server{
|
||||||
|
Addr: f.listenAddress,
|
||||||
|
Net: "tcp",
|
||||||
|
Handler: tcpMux,
|
||||||
|
}
|
||||||
|
|
||||||
f.UpdateDomains(entries)
|
f.UpdateDomains(entries)
|
||||||
|
|
||||||
return dnsServer.ListenAndServe()
|
errCh := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
||||||
|
errCh <- f.dnsServer.ListenAndServe()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
||||||
|
errCh <- f.tcpServer.ListenAndServe()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// return the first error we get (e.g. bind failure or shutdown)
|
||||||
|
return <-errCh
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
|
|
||||||
if f.mux == nil {
|
|
||||||
log.Debug("DNS mux is nil, skipping domain update")
|
|
||||||
f.fwdEntries = entries
|
f.fwdEntries = entries
|
||||||
return
|
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||||
}
|
|
||||||
|
|
||||||
oldDomains := filterDomains(f.fwdEntries)
|
|
||||||
|
|
||||||
for _, d := range oldDomains {
|
|
||||||
f.mux.HandleRemove(d.PunycodeString())
|
|
||||||
}
|
|
||||||
|
|
||||||
newDomains := filterDomains(entries)
|
|
||||||
for _, d := range newDomains {
|
|
||||||
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.fwdEntries = entries
|
|
||||||
|
|
||||||
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
if f.dnsServer == nil {
|
var result *multierror.Error
|
||||||
return nil
|
|
||||||
|
if f.dnsServer != nil {
|
||||||
|
if err := f.dnsServer.ShutdownContext(ctx); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err))
|
||||||
}
|
}
|
||||||
return f.dnsServer.ShutdownContext(ctx)
|
}
|
||||||
|
if f.tcpServer != nil {
|
||||||
|
if err := f.tcpServer.ShutdownContext(ctx); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||||
@@ -123,28 +148,69 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||||
|
// query doesn't match any configured domain
|
||||||
|
if mostSpecificResId == "" {
|
||||||
|
resp.Rcode = dns.RcodeRefused
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.handleDNSError(w, resp, domain, err)
|
f.handleDNSError(w, query, resp, domain, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
resp := f.handleDNSQuery(w, query)
|
||||||
|
if resp == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(domain, ips)
|
opt := query.IsEdns0()
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
// client advertised a larger EDNS0 buffer
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// if our response is too big, truncate and set the TC bit
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
resp := f.handleDNSQuery(w, query)
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
var prefixes []netip.Prefix
|
var prefixes []netip.Prefix
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
|
||||||
if mostSpecificResId != "" {
|
if mostSpecificResId != "" {
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var prefix netip.Prefix
|
var prefix netip.Prefix
|
||||||
@@ -179,7 +245,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
|
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@@ -191,7 +257,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai
|
|||||||
}
|
}
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
if dnsErr.Server != "" {
|
||||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
|
||||||
} else {
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
}
|
}
|
||||||
@@ -275,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
|||||||
|
|
||||||
return selectedResId, matches
|
return selectedResId, matches
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterDomains returns a list of normalized domains
|
|
||||||
func filterDomains(entries []*ForwarderEntry) domain.List {
|
|
||||||
newDomains := make(domain.List, 0, len(entries))
|
|
||||||
for _, d := range entries {
|
|
||||||
if d.Domain == "" {
|
|
||||||
log.Warn("empty domain in DNS forwarder")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
|
|
||||||
}
|
|
||||||
return newDomains
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,11 +1,21 @@
|
|||||||
package dnsfwd
|
package dnsfwd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -13,7 +23,7 @@ import (
|
|||||||
func Test_getMatchingEntries(t *testing.T) {
|
func Test_getMatchingEntries(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
storedMappings map[string]route.ResID // key: domain pattern, value: resId
|
storedMappings map[string]route.ResID
|
||||||
queryDomain string
|
queryDomain string
|
||||||
expectedResId route.ResID
|
expectedResId route.ResID
|
||||||
}{
|
}{
|
||||||
@@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Wildcard pattern does not match different domain",
|
name: "Wildcard pattern does not match different domain",
|
||||||
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
||||||
queryDomain: "foo.notexample.com",
|
queryDomain: "foo.example.org",
|
||||||
expectedResId: "",
|
expectedResId: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MockFirewall struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
args := m.Called(set, prefixes)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockResolver struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
args := m.Called(ctx, network, host)
|
||||||
|
return args.Get(0).([]netip.Addr), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomain string
|
||||||
|
queryDomain string
|
||||||
|
shouldMatch bool
|
||||||
|
expectedResID route.ResID
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain match should be allowed",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Direct match to configured domain should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain access should be restricted",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldMatch: false,
|
||||||
|
expectedResID: "",
|
||||||
|
description: "Subdomain should not be accessible unless explicitly configured",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard should allow subdomains",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard domains should allow subdomain access",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard should allow base domain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard should also match the base domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain should be restricted",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldMatch: false,
|
||||||
|
expectedResID: "",
|
||||||
|
description: "Deep subdomains should not be accessible",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows deep subdomains",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard should allow deep subdomains",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
forwarder := &DNSForwarder{}
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{
|
||||||
|
{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res-id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
|
||||||
|
|
||||||
|
if tt.shouldMatch {
|
||||||
|
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
|
||||||
|
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
|
||||||
|
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
|
||||||
|
assert.Empty(t, matchingEntries, "Expected no matching entries")
|
||||||
|
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomain string
|
||||||
|
queryDomain string
|
||||||
|
shouldResolve bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "configured exact domain resolves",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Exact match should resolve",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized subdomain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Subdomain should be blocked without wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows subdomain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow subdomain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows base domain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow base domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unrelated domain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.org",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Unrelated domain should be blocked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Deep subdomain should be blocked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows deep subdomain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow deep subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
if tt.shouldResolve {
|
||||||
|
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
|
||||||
|
|
||||||
|
// Mock successful DNS resolution
|
||||||
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{
|
||||||
|
{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res-id",
|
||||||
|
Set: firewall.NewDomainSet([]domain.Domain{d}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
if tt.shouldResolve {
|
||||||
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
|
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
} else {
|
||||||
|
if resp != nil {
|
||||||
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
|
"Unauthorized domain should not return successful answers")
|
||||||
|
}
|
||||||
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomains []string
|
||||||
|
query string
|
||||||
|
mockIP string
|
||||||
|
shouldResolve bool
|
||||||
|
expectedSetCount int // How many sets should be updated
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain gets firewall update",
|
||||||
|
configuredDomains: []string{"example.com"},
|
||||||
|
query: "example.com",
|
||||||
|
mockIP: "1.1.1.1",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 1,
|
||||||
|
description: "Single exact match updates one set",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard domain gets firewall update",
|
||||||
|
configuredDomains: []string{"*.example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.2",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 1,
|
||||||
|
description: "Wildcard match updates one set",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overlapping exact and wildcard both get updates",
|
||||||
|
configuredDomains: []string{"*.example.com", "mail.example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.3",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 2,
|
||||||
|
description: "Both exact and wildcard sets should be updated",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized domain gets no firewall update",
|
||||||
|
configuredDomains: []string{"example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.4",
|
||||||
|
shouldResolve: false,
|
||||||
|
expectedSetCount: 0,
|
||||||
|
description: "No firewall update for unauthorized domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple wildcards matching get all updated",
|
||||||
|
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
|
||||||
|
query: "test.sub.example.com",
|
||||||
|
mockIP: "1.1.1.5",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 2,
|
||||||
|
description: "All matching wildcard sets should be updated",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
// Set up forwarder
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Create entries and track sets
|
||||||
|
var entries []*ForwarderEntry
|
||||||
|
sets := make([]firewall.Set, 0)
|
||||||
|
|
||||||
|
for i, configDomain := range tt.configuredDomains {
|
||||||
|
d, err := domain.FromString(configDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
sets = append(sets, set)
|
||||||
|
|
||||||
|
entries = append(entries, &ForwarderEntry{
|
||||||
|
Domain: d,
|
||||||
|
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
|
||||||
|
Set: set,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Set up mocks
|
||||||
|
if tt.shouldResolve {
|
||||||
|
fakeIP := netip.MustParseAddr(tt.mockIP)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
|
||||||
|
Return([]netip.Addr{fakeIP}, nil).Once()
|
||||||
|
|
||||||
|
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
|
||||||
|
|
||||||
|
// Count how many sets should actually match
|
||||||
|
updateCount := 0
|
||||||
|
for i, entry := range entries {
|
||||||
|
domain := strings.ToLower(tt.query)
|
||||||
|
pattern := entry.Domain.PunycodeString()
|
||||||
|
|
||||||
|
matches := false
|
||||||
|
if strings.HasPrefix(pattern, "*.") {
|
||||||
|
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||||
|
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
} else if domain == pattern {
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if matches {
|
||||||
|
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
|
||||||
|
updateCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedSetCount, updateCount,
|
||||||
|
"Expected %d sets to be updated, but mock expects %d",
|
||||||
|
tt.expectedSetCount, updateCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
dnsQuery := &dns.Msg{}
|
||||||
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if tt.shouldResolve {
|
||||||
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.NotEmpty(t, resp.Answer)
|
||||||
|
} else if resp != nil {
|
||||||
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
|
"Unauthorized domain should be refused or have no answers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all mock expectations were met
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
|
||||||
|
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Configure a single domain
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
entries := []*ForwarderEntry{{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res",
|
||||||
|
Set: set,
|
||||||
|
}}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Mock resolver returns multiple IPs
|
||||||
|
ips := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
netip.MustParseAddr("1.1.1.2"),
|
||||||
|
netip.MustParseAddr("1.1.1.3"),
|
||||||
|
}
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return(ips, nil).Once()
|
||||||
|
|
||||||
|
// Expect ONE UpdateSet call with ALL prefixes
|
||||||
|
expectedPrefixes := []netip.Prefix{
|
||||||
|
netip.PrefixFrom(ips[0], 32),
|
||||||
|
netip.PrefixFrom(ips[1], 32),
|
||||||
|
netip.PrefixFrom(ips[2], 32),
|
||||||
|
}
|
||||||
|
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// Verify response contains all IPs
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryType uint16
|
||||||
|
queryDomain string
|
||||||
|
configured string
|
||||||
|
expectedCode int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unauthorized domain returns REFUSED",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
queryDomain: "evil.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeRefused,
|
||||||
|
description: "RFC compliant REFUSED for unauthorized queries",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported query type returns NOTIMP",
|
||||||
|
queryType: dns.TypeMX,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "RFC compliant NOTIMP for unsupported types",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CNAME query returns NOTIMP",
|
||||||
|
queryType: dns.TypeCNAME,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "CNAME queries not supported",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TXT query returns NOTIMP",
|
||||||
|
queryType: dns.TypeTXT,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "TXT queries not supported",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configured)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
||||||
|
|
||||||
|
// Capture the written response
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// Check the response written to the writer
|
||||||
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||||
|
// Test that large UDP responses are truncated with TC bit set
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, _ := domain.FromString("example.com")
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Mock many IPs to create a large response
|
||||||
|
var manyIPs []netip.Addr
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
|
||||||
|
}
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
|
||||||
|
|
||||||
|
// Query without EDNS0
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
forwarder.handleDNSQueryUDP(mockWriter, query)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp)
|
||||||
|
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
|
||||||
|
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||||
|
// Test complex overlapping pattern scenarios
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Set up complex overlapping patterns
|
||||||
|
patterns := []string{
|
||||||
|
"*.example.com", // Matches all subdomains
|
||||||
|
"*.mail.example.com", // More specific wildcard
|
||||||
|
"smtp.mail.example.com", // Exact match
|
||||||
|
"example.com", // Base domain
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []*ForwarderEntry
|
||||||
|
sets := make(map[string]firewall.Set)
|
||||||
|
|
||||||
|
for _, pattern := range patterns {
|
||||||
|
d, _ := domain.FromString(pattern)
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
sets[pattern] = set
|
||||||
|
entries = append(entries, &ForwarderEntry{
|
||||||
|
Domain: d,
|
||||||
|
ResID: route.ResID("res-" + pattern),
|
||||||
|
Set: set,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Test smtp.mail.example.com - should match 3 patterns
|
||||||
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
|
||||||
|
|
||||||
|
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
|
||||||
|
// All three matching patterns should get firewall updates
|
||||||
|
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
|
// Verify all three sets were updated
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Verify the most specific ResID was selected
|
||||||
|
// (exact match should win over wildcards)
|
||||||
|
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
|
||||||
|
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
|
||||||
|
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||||
|
// Test handling of malformed query with no questions
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
// Don't set any question
|
||||||
|
|
||||||
|
writeCalled := false
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writeCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
assert.Nil(t, resp, "Should return nil for empty query")
|
||||||
|
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||||
|
}
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ type Manager struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
fwRules []firewall.Rule
|
fwRules []firewall.Rule
|
||||||
|
tcpRules []firewall.Rule
|
||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error {
|
|||||||
}
|
}
|
||||||
m.fwRules = dnsRules
|
m.fwRules = dnsRules
|
||||||
|
|
||||||
|
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.tcpRules = tcpRules
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error {
|
|||||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, rule := range m.tcpRules {
|
||||||
|
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.fwRules = nil
|
m.fwRules = nil
|
||||||
|
m.tcpRules = nil
|
||||||
return nberrors.FormatErrorOrNil(mErr)
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ import (
|
|||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
@@ -120,8 +121,10 @@ type EngineConfig struct {
|
|||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
DisableDNS bool
|
DisableDNS bool
|
||||||
DisableFirewall bool
|
DisableFirewall bool
|
||||||
|
|
||||||
BlockLANAccess bool
|
BlockLANAccess bool
|
||||||
|
BlockInbound bool
|
||||||
|
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@@ -134,6 +137,8 @@ type Engine struct {
|
|||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
|
|
||||||
|
connMgr *ConnMgr
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
beforePeerHook nbnet.AddHookFunc
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
afterPeerHook nbnet.RemoveHookFunc
|
||||||
|
|
||||||
@@ -171,6 +176,7 @@ type Engine struct {
|
|||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
|
||||||
firewall firewallManager.Manager
|
firewall firewallManager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
@@ -235,6 +241,8 @@ func NewEngine(
|
|||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
path := statemanager.GetDefaultStatePath()
|
||||||
if runtime.GOOS == "ios" {
|
if runtime.GOOS == "ios" {
|
||||||
if !fileExists(mobileDep.StateFilePath) {
|
if !fileExists(mobileDep.StateFilePath) {
|
||||||
err := createFile(mobileDep.StateFilePath)
|
err := createFile(mobileDep.StateFilePath)
|
||||||
@@ -244,11 +252,9 @@ func NewEngine(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
|
path = mobileDep.StateFilePath
|
||||||
}
|
}
|
||||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
|
||||||
engine.stateManager = statemanager.New(path)
|
engine.stateManager = statemanager.New(path)
|
||||||
}
|
|
||||||
|
|
||||||
return engine
|
return engine
|
||||||
}
|
}
|
||||||
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if e.connMgr != nil {
|
||||||
|
e.connMgr.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// stopping network monitor first to avoid starting the engine again
|
// stopping network monitor first to avoid starting the engine again
|
||||||
if e.networkMonitor != nil {
|
if e.networkMonitor != nil {
|
||||||
e.networkMonitor.Stop()
|
e.networkMonitor.Stop()
|
||||||
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
|
|||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
if err := e.removeAllPeers(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,6 +359,7 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("new wg interface: %w", err)
|
return fmt.Errorf("new wg interface: %w", err)
|
||||||
}
|
}
|
||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
e.statusRecorder.SetWgIface(wgIface)
|
||||||
|
|
||||||
// start flow manager right after interface creation
|
// start flow manager right after interface creation
|
||||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||||
@@ -371,7 +381,6 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e.stateManager.Start()
|
e.stateManager.Start()
|
||||||
|
|
||||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||||
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
err = e.wgInterfaceCreate()
|
if err = e.wgInterfaceCreate(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
@@ -423,7 +431,8 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("up wg interface: %w", err)
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
// if inbound conns are blocked there is no need to create the ACL manager
|
||||||
|
if e.firewall != nil && !e.config.BlockInbound {
|
||||||
e.acl = acl.NewDefaultManager(e.firewall)
|
e.acl = acl.NewDefaultManager(e.firewall)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -442,6 +451,11 @@ func (e *Engine) Start() error {
|
|||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
||||||
|
|
||||||
|
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
||||||
|
e.connMgr.Start(e.ctx)
|
||||||
|
|
||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
e.srWatcher.Start()
|
e.srWatcher.Start()
|
||||||
|
|
||||||
@@ -450,7 +464,6 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,12 +488,10 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) initFirewall() error {
|
func (e *Engine) initFirewall() error {
|
||||||
if e.firewall.IsServerRouteSupported() {
|
|
||||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
return fmt.Errorf("enable server router: %w", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if e.config.BlockLANAccess {
|
if e.config.BlockLANAccess {
|
||||||
e.blockLanAccess()
|
e.blockLanAccess()
|
||||||
@@ -513,6 +524,11 @@ func (e *Engine) initFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) blockLanAccess() {
|
func (e *Engine) blockLanAccess() {
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
// no need to set up extra deny rules if inbound is already blocked in general
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
// TODO: keep this updated
|
// TODO: keep this updated
|
||||||
@@ -550,6 +566,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
|
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentPeer.AgentVersionString() != p.AgentVersion {
|
||||||
|
modified = append(modified, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
@@ -559,8 +585,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
|
if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -621,17 +646,12 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
e.connMgr.RemovePeerConn(peerKey)
|
||||||
|
|
||||||
err := e.statusRecorder.RemovePeer(peerKey)
|
err := e.statusRecorder.RemovePeer(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
conn, exists := e.peerStore.Remove(peerKey)
|
|
||||||
if exists {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -766,6 +786,9 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
e.config.DisableFirewall,
|
e.config.DisableFirewall,
|
||||||
|
e.config.BlockLANAccess,
|
||||||
|
e.config.BlockInbound,
|
||||||
|
e.config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -780,11 +803,15 @@ func isNil(server nbssh.Server) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if !e.config.ServerSSHAllowed {
|
if !e.config.ServerSSHAllowed {
|
||||||
log.Warnf("running SSH server is not permitted")
|
log.Info("SSH server is not enabled")
|
||||||
return nil
|
return nil
|
||||||
} else {
|
}
|
||||||
|
|
||||||
if sshConf.GetSshEnabled() {
|
if sshConf.GetSshEnabled() {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
@@ -828,8 +855,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
e.sshServer = nil
|
e.sshServer = nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
@@ -883,6 +908,9 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
e.config.DisableFirewall,
|
e.config.DisableFirewall,
|
||||||
|
e.config.BlockLANAccess,
|
||||||
|
e.config.BlockInbound,
|
||||||
|
e.config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||||
@@ -952,20 +980,49 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||||
|
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||||
log.Errorf("failed to update local IPs: %v", err)
|
log.Errorf("failed to update local IPs: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||||
|
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
|
||||||
|
// This needs to be toggled before applying routes.
|
||||||
|
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||||
|
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||||
|
log.Errorf("failed to set legacy management flag: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protoDNSConfig := networkMap.GetDNSConfig()
|
||||||
|
if protoDNSConfig == nil {
|
||||||
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||||
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
|
|
||||||
// apply routes first, route related actions might depend on routing being enabled
|
// apply routes first, route related actions might depend on routing being enabled
|
||||||
routes := toRoutes(networkMap.GetRoutes())
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
|
||||||
|
// lazy mgr needs to be aware of which routes are available before they are applied
|
||||||
|
if e.connMgr != nil {
|
||||||
|
e.connMgr.UpdateRouteHAMap(clientRoutes)
|
||||||
|
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
|
||||||
|
log.Errorf("failed to update routes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.acl != nil {
|
if e.acl != nil {
|
||||||
@@ -976,7 +1033,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||||
|
|
||||||
// Ingress forward rules
|
// Ingress forward rules
|
||||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||||
|
if err != nil {
|
||||||
log.Errorf("failed to update forward rules, err: %v", err)
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1022,14 +1080,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protoDNSConfig := networkMap.GetDNSConfig()
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
if protoDNSConfig == nil {
|
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
|
||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||||
}
|
|
||||||
|
|
||||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.networkSerial = serial
|
e.networkSerial = serial
|
||||||
|
|
||||||
@@ -1065,7 +1118,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
|
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: route.ID(protoRoute.ID),
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix.Masked(),
|
||||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||||
NetID: route.NetID(protoRoute.NetID),
|
NetID: route.NetID(protoRoute.NetID),
|
||||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||||
@@ -1099,7 +1152,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
|||||||
return entries
|
return entries
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||||
dnsUpdate := nbdns.Config{
|
dnsUpdate := nbdns.Config{
|
||||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||||
CustomZones: make([]nbdns.CustomZone, 0),
|
CustomZones: make([]nbdns.CustomZone, 0),
|
||||||
@@ -1155,7 +1208,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||||
PubKey: offlinePeer.GetWgPubKey(),
|
PubKey: offlinePeer.GetWgPubKey(),
|
||||||
FQDN: offlinePeer.GetFqdn(),
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
ConnStatus: peer.StatusDisconnected,
|
ConnStatus: peer.StatusIdle,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
Mux: new(sync.RWMutex),
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
@@ -1191,12 +1244,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
peerIPs = append(peerIPs, allowedNetIP)
|
peerIPs = append(peerIPs, allowedNetIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := e.createPeerConn(peerKey, peerIPs)
|
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create peer connection: %w", err)
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
}
|
}
|
||||||
@@ -1205,17 +1263,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Open()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
|
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
|
||||||
log.Debugf("creating peer connection %s", pubKey)
|
log.Debugf("creating peer connection %s", pubKey)
|
||||||
|
|
||||||
wgConfig := peer.WgConfig{
|
wgConfig := peer.WgConfig{
|
||||||
@@ -1231,6 +1282,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
|||||||
config := peer.ConnConfig{
|
config := peer.ConnConfig{
|
||||||
Key: pubKey,
|
Key: pubKey,
|
||||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
|
AgentVersion: agentVersion,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
WgConfig: wgConfig,
|
WgConfig: wgConfig,
|
||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
@@ -1249,7 +1301,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
|
StatusRecorder: e.statusRecorder,
|
||||||
|
Signaler: e.signaler,
|
||||||
|
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||||
|
RelayManager: e.relayManager,
|
||||||
|
SrWatcher: e.srWatcher,
|
||||||
|
Semaphore: e.connSemaphore,
|
||||||
|
PeerConnDispatcher: e.peerConnDispatcher,
|
||||||
|
}
|
||||||
|
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1270,7 +1331,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
}
|
}
|
||||||
@@ -1406,6 +1467,7 @@ func (e *Engine) close() {
|
|||||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||||
}
|
}
|
||||||
e.wgInterface = nil
|
e.wgInterface = nil
|
||||||
|
e.statusRecorder.SetWgIface(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isNil(e.sshServer) {
|
if !isNil(e.sshServer) {
|
||||||
@@ -1437,6 +1499,9 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
|||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
e.config.DisableFirewall,
|
e.config.DisableFirewall,
|
||||||
|
e.config.BlockLANAccess,
|
||||||
|
e.config.BlockInbound,
|
||||||
|
e.config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
@@ -1462,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
|||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
TransportNet: transportNet,
|
TransportNet: transportNet,
|
||||||
FilterFn: e.addrViaRoutes,
|
FilterFn: e.addrViaRoutes,
|
||||||
|
DisableDNS: e.config.DisableDNS,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
@@ -1578,13 +1644,39 @@ func (e *Engine) getRosenpassAddr() string {
|
|||||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||||
// and updates the status recorder with the latest states.
|
// and updates the status recorder with the latest states.
|
||||||
func (e *Engine) RunHealthProbes() bool {
|
func (e *Engine) RunHealthProbes() bool {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
|
||||||
signalHealthy := e.signal.IsHealthy()
|
signalHealthy := e.signal.IsHealthy()
|
||||||
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
||||||
|
|
||||||
managementHealthy := e.mgmClient.IsHealthy()
|
managementHealthy := e.mgmClient.IsHealthy()
|
||||||
log.Debugf("management health check: healthy=%t", managementHealthy)
|
log.Debugf("management health check: healthy=%t", managementHealthy)
|
||||||
|
|
||||||
results := append(e.probeSTUNs(), e.probeTURNs()...)
|
stuns := slices.Clone(e.STUNs)
|
||||||
|
turns := slices.Clone(e.TURNs)
|
||||||
|
|
||||||
|
if e.wgInterface != nil {
|
||||||
|
stats, err := e.wgInterface.GetStats()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get wireguard stats: %v", err)
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, key := range e.peerStore.PeersPubKey() {
|
||||||
|
// wgStats could be zero value, in which case we just reset the stats
|
||||||
|
wgStats, ok := stats[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||||
|
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
results := e.probeICE(stuns, turns)
|
||||||
e.statusRecorder.UpdateRelayStates(results)
|
e.statusRecorder.UpdateRelayStates(results)
|
||||||
|
|
||||||
relayHealthy := true
|
relayHealthy := true
|
||||||
@@ -1596,37 +1688,16 @@ func (e *Engine) RunHealthProbes() bool {
|
|||||||
}
|
}
|
||||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||||
|
|
||||||
for _, key := range e.peerStore.PeersPubKey() {
|
|
||||||
wgStats, err := e.wgInterface.GetStats(key)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// wgStats could be zero value, in which case we just reset the stats
|
|
||||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
|
||||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||||
return allHealthy
|
return allHealthy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||||
e.syncMsgMux.Lock()
|
return append(
|
||||||
stuns := slices.Clone(e.STUNs)
|
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||||
e.syncMsgMux.Unlock()
|
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
|
||||||
|
)
|
||||||
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
turns := slices.Clone(e.TURNs)
|
|
||||||
e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// restartEngine restarts the engine by cancelling the client context
|
// restartEngine restarts the engine by cancelling the client context
|
||||||
@@ -1738,9 +1809,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetWgAddr returns the wireguard address
|
// GetWgAddr returns the wireguard address
|
||||||
func (e *Engine) GetWgAddr() net.IP {
|
func (e *Engine) GetWgAddr() netip.Addr {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return nil
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
return e.wgInterface.Address().IP
|
return e.wgInterface.Address().IP
|
||||||
}
|
}
|
||||||
@@ -1750,6 +1821,10 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
enabled bool,
|
enabled bool,
|
||||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||||
) {
|
) {
|
||||||
|
if e.config.DisableServerRoutes {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !enabled {
|
if !enabled {
|
||||||
if e.dnsForwardMgr == nil {
|
if e.dnsForwardMgr == nil {
|
||||||
return
|
return
|
||||||
@@ -1805,29 +1880,24 @@ func (e *Engine) Address() (netip.Addr, error) {
|
|||||||
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := e.wgInterface.Address()
|
return e.wgInterface.Address().IP, nil
|
||||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
|
|
||||||
}
|
|
||||||
return ip.Unmap(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||||
if e.firewall == nil {
|
if e.firewall == nil {
|
||||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(rules) == 0 {
|
if len(rules) == 0 {
|
||||||
if e.ingressGatewayMgr == nil {
|
if e.ingressGatewayMgr == nil {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.ingressGatewayMgr.Close()
|
err := e.ingressGatewayMgr.Close()
|
||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
e.statusRecorder.SetIngressGwMgr(nil)
|
e.statusRecorder.SetIngressGwMgr(nil)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.ingressGatewayMgr == nil {
|
if e.ingressGatewayMgr == nil {
|
||||||
@@ -1878,7 +1948,25 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
|||||||
log.Errorf("failed to update forwarding rules: %v", err)
|
log.Errorf("failed to update forwarding rules: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return forwardingRules, nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) toExcludedLazyPeers(rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
|
||||||
|
excludedPeers := make(map[string]bool)
|
||||||
|
for _, r := range rules {
|
||||||
|
ip := r.TranslatedAddress
|
||||||
|
for _, p := range peers {
|
||||||
|
for _, allowedIP := range p.GetAllowedIps() {
|
||||||
|
if allowedIP != ip.String() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
|
||||||
|
excludedPeers[p.GetWgPubKey()] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return excludedPeers
|
||||||
}
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user