mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-12 03:39:55 +00:00
Compare commits
1 Commits
v0.70.3
...
vnc-server
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b754df1171 |
158
.github/workflows/release.yml
vendored
158
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.4"
|
SIGN_PIPE_VER: "v0.1.2"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -115,12 +115,6 @@ jobs:
|
|||||||
|
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
outputs:
|
|
||||||
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
|
|
||||||
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
|
|
||||||
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
|
|
||||||
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
|
|
||||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
|
||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
@@ -219,13 +213,10 @@ jobs:
|
|||||||
if: always()
|
if: always()
|
||||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: Tag and push images (amd64 only)
|
- name: Tag and push images (amd64 only)
|
||||||
id: tag_and_push_images
|
|
||||||
if: |
|
if: |
|
||||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||||
run: |
|
run: |
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
resolve_tags() {
|
resolve_tags() {
|
||||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||||
echo "pr-${{ github.event.pull_request.number }}"
|
echo "pr-${{ github.event.pull_request.number }}"
|
||||||
@@ -234,17 +225,6 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
ghcr_package_url() {
|
|
||||||
local image="$1" package encoded_package
|
|
||||||
package="${image#ghcr.io/}"
|
|
||||||
package="${package#*/}"
|
|
||||||
package="${package%%:*}"
|
|
||||||
encoded_package="${package//\//%2F}"
|
|
||||||
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
|
|
||||||
}
|
|
||||||
|
|
||||||
image_refs=()
|
|
||||||
|
|
||||||
tag_and_push() {
|
tag_and_push() {
|
||||||
local src="$1" img_name tag dst
|
local src="$1" img_name tag dst
|
||||||
img_name="${src%%:*}"
|
img_name="${src%%:*}"
|
||||||
@@ -253,56 +233,35 @@ jobs:
|
|||||||
echo "Tagging ${src} -> ${dst}"
|
echo "Tagging ${src} -> ${dst}"
|
||||||
docker tag "$src" "$dst"
|
docker tag "$src" "$dst"
|
||||||
docker push "$dst"
|
docker push "$dst"
|
||||||
image_refs+=("$dst")
|
|
||||||
done
|
done
|
||||||
}
|
}
|
||||||
|
|
||||||
cat > /tmp/goreleaser-artifacts.json <<'JSON'
|
export -f tag_and_push resolve_tags
|
||||||
${{ steps.goreleaser.outputs.artifacts }}
|
|
||||||
JSON
|
|
||||||
|
|
||||||
mapfile -t src_images < <(
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
)
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
|
tag_and_push "$SRC"
|
||||||
for src in "${src_images[@]}"; do
|
done
|
||||||
tag_and_push "$src"
|
|
||||||
done
|
|
||||||
|
|
||||||
{
|
|
||||||
echo "images_markdown<<EOF"
|
|
||||||
if [[ ${#image_refs[@]} -eq 0 ]]; then
|
|
||||||
echo "_No GHCR images were pushed._"
|
|
||||||
else
|
|
||||||
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
|
|
||||||
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
echo "EOF"
|
|
||||||
} >> "$GITHUB_OUTPUT"
|
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload linux packages
|
- name: upload linux packages
|
||||||
id: upload_linux_packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-packages
|
name: linux-packages
|
||||||
path: dist/netbird_linux**
|
path: dist/netbird_linux**
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload windows packages
|
- name: upload windows packages
|
||||||
id: upload_windows_packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-packages
|
name: windows-packages
|
||||||
path: dist/netbird_windows**
|
path: dist/netbird_windows**
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload macos packages
|
- name: upload macos packages
|
||||||
id: upload_macos_packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: macos-packages
|
name: macos-packages
|
||||||
@@ -311,8 +270,6 @@ jobs:
|
|||||||
|
|
||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
|
||||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
|
||||||
steps:
|
steps:
|
||||||
- name: Parse semver string
|
- name: Parse semver string
|
||||||
id: semver_parser
|
id: semver_parser
|
||||||
@@ -403,7 +360,6 @@ jobs:
|
|||||||
if: always()
|
if: always()
|
||||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release_ui
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
@@ -412,8 +368,6 @@ jobs:
|
|||||||
|
|
||||||
release_ui_darwin:
|
release_ui_darwin:
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
outputs:
|
|
||||||
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
|
|
||||||
steps:
|
steps:
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
@@ -448,110 +402,12 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release_ui_darwin
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui-darwin
|
name: release-ui-darwin
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 3
|
retention-days: 3
|
||||||
|
|
||||||
comment_release_artifacts:
|
|
||||||
name: Comment release artifacts
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: [release, release_ui, release_ui_darwin]
|
|
||||||
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
issues: write
|
|
||||||
pull-requests: write
|
|
||||||
steps:
|
|
||||||
- name: Create or update PR comment
|
|
||||||
uses: actions/github-script@v7
|
|
||||||
env:
|
|
||||||
RELEASE_RESULT: ${{ needs.release.result }}
|
|
||||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
|
||||||
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
|
|
||||||
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
|
|
||||||
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
|
|
||||||
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
|
|
||||||
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
|
|
||||||
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
|
|
||||||
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
|
|
||||||
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
|
|
||||||
with:
|
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
script: |
|
|
||||||
const marker = '<!-- netbird-release-artifacts -->';
|
|
||||||
const { owner, repo } = context.repo;
|
|
||||||
const issue_number = context.payload.pull_request.number;
|
|
||||||
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
|
|
||||||
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
|
|
||||||
|
|
||||||
const artifactCell = (url, result) => {
|
|
||||||
if (url) return `[Download](${url})`;
|
|
||||||
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
|
|
||||||
};
|
|
||||||
|
|
||||||
const artifacts = [
|
|
||||||
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
|
|
||||||
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
|
|
||||||
];
|
|
||||||
|
|
||||||
const artifactRows = artifacts
|
|
||||||
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
|
|
||||||
.join('\n');
|
|
||||||
|
|
||||||
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
|
|
||||||
|
|
||||||
const body = [
|
|
||||||
marker,
|
|
||||||
'## Release artifacts',
|
|
||||||
'',
|
|
||||||
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
|
|
||||||
'',
|
|
||||||
'| Artifact | Link |',
|
|
||||||
'| --- | --- |',
|
|
||||||
artifactRows,
|
|
||||||
'',
|
|
||||||
'### GHCR images (amd64)',
|
|
||||||
ghcrImages,
|
|
||||||
'',
|
|
||||||
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
|
|
||||||
].join('\n');
|
|
||||||
|
|
||||||
const comments = await github.paginate(github.rest.issues.listComments, {
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
issue_number,
|
|
||||||
per_page: 100,
|
|
||||||
});
|
|
||||||
|
|
||||||
const previous = comments.find(comment =>
|
|
||||||
comment.user?.type === 'Bot' && comment.body?.includes(marker)
|
|
||||||
);
|
|
||||||
|
|
||||||
if (previous) {
|
|
||||||
await github.rest.issues.updateComment({
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
comment_id: previous.id,
|
|
||||||
body,
|
|
||||||
});
|
|
||||||
core.info(`Updated release artifacts comment ${previous.id}`);
|
|
||||||
} else {
|
|
||||||
const { data } = await github.rest.issues.createComment({
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
issue_number,
|
|
||||||
body,
|
|
||||||
});
|
|
||||||
core.info(`Created release artifacts comment ${data.id}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
trigger_signer:
|
trigger_signer:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [release, release_ui, release_ui_darwin]
|
needs: [release, release_ui, release_ui_darwin]
|
||||||
|
|||||||
28
.github/workflows/sync-tag.yml
vendored
28
.github/workflows/sync-tag.yml
vendored
@@ -9,8 +9,6 @@ concurrency:
|
|||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
|
|
||||||
# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
|
|
||||||
jobs:
|
jobs:
|
||||||
trigger_sync_tag:
|
trigger_sync_tag:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -23,29 +21,3 @@ jobs:
|
|||||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||||
|
|
||||||
trigger_android_bump:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
|
||||||
steps:
|
|
||||||
- name: Trigger android-client submodule bump
|
|
||||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
|
||||||
with:
|
|
||||||
workflow: bump-netbird.yml
|
|
||||||
ref: main
|
|
||||||
repo: netbirdio/android-client
|
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
|
||||||
|
|
||||||
trigger_ios_bump:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
|
||||||
steps:
|
|
||||||
- name: Trigger ios-client submodule bump
|
|
||||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
|
||||||
with:
|
|
||||||
workflow: bump-netbird.yml
|
|
||||||
ref: main
|
|
||||||
repo: netbirdio/ios-client
|
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
|
||||||
@@ -151,6 +151,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(logoutCmd)
|
rootCmd.AddCommand(logoutCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
|
rootCmd.AddCommand(vncCmd)
|
||||||
rootCmd.AddCommand(networksCMD)
|
rootCmd.AddCommand(networksCMD)
|
||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|||||||
@@ -36,7 +36,10 @@ const (
|
|||||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
disableSSHAuthFlag = "disable-ssh-auth"
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
jwtCacheTTLFlag = "jwt-cache-ttl"
|
||||||
|
|
||||||
|
// Alias for backward compatibility.
|
||||||
|
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -61,7 +64,7 @@ var (
|
|||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
disableSSHAuth bool
|
disableSSHAuth bool
|
||||||
sshJWTCacheTTL int
|
jwtCacheTTL int
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -71,7 +74,9 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
|
||||||
|
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
|
||||||
|
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
|
||||||
|
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
req.ServerSSHAllowed = &serverSSHAllowed
|
req.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||||
|
req.ServerVNCAllowed = &serverVNCAllowed
|
||||||
|
}
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
req.EnableSSHRoot = &enableSSHRoot
|
req.EnableSSHRoot = &enableSSHRoot
|
||||||
}
|
}
|
||||||
@@ -371,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
req.DisableSSHAuth = &disableSSHAuth
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
req.DisableVNCAuth = &disableVNCAuth
|
||||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
}
|
||||||
|
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||||
|
req.SshJWTCacheTTL = &jwtCacheTTL32
|
||||||
}
|
}
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
@@ -458,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||||
|
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
ic.EnableSSHRoot = &enableSSHRoot
|
ic.EnableSSHRoot = &enableSSHRoot
|
||||||
@@ -479,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.DisableSSHAuth = &disableSSHAuth
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
ic.DisableVNCAuth = &disableVNCAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
ic.SSHJWTCacheTTL = &jwtCacheTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
@@ -582,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||||
|
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||||
@@ -603,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
loginRequest.DisableVNCAuth = &disableVNCAuth
|
||||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||||
|
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
|||||||
271
client/cmd/vnc.go
Normal file
271
client/cmd/vnc.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
vncUsername string
|
||||||
|
vncHost string
|
||||||
|
vncMode string
|
||||||
|
vncListen string
|
||||||
|
vncNoBrowser bool
|
||||||
|
vncNoCache bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
|
||||||
|
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
|
||||||
|
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
|
||||||
|
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncCmd = &cobra.Command{
|
||||||
|
Use: "vnc [flags] [user@]host",
|
||||||
|
Short: "Connect to a NetBird peer via VNC",
|
||||||
|
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
|
||||||
|
The target peer must have the VNC server enabled.
|
||||||
|
|
||||||
|
Two modes are available:
|
||||||
|
- attach: view the current physical display (remote support)
|
||||||
|
- session: start a virtual desktop as the specified user (passwordless login)
|
||||||
|
|
||||||
|
Use --listen to start a local proxy for external VNC viewers:
|
||||||
|
netbird vnc --listen :5900 peer-hostname
|
||||||
|
vncviewer localhost:5900
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird vnc peer-hostname
|
||||||
|
netbird vnc --mode session --user alice peer-hostname
|
||||||
|
netbird vnc --listen :5900 peer-hostname`,
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: vncFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncFn(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
SetFlagsFromEnvVars(cmd)
|
||||||
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
logOutput := "console"
|
||||||
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
|
logOutput = firstLogFile
|
||||||
|
}
|
||||||
|
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := parseVNCHostArg(args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
vncCtx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := runVNC(vncCtx, cmd); err != nil {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-sig:
|
||||||
|
cancel()
|
||||||
|
<-vncCtx.Done()
|
||||||
|
return nil
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
|
case <-vncCtx.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseVNCHostArg(arg string) error {
|
||||||
|
if strings.Contains(arg, "@") {
|
||||||
|
parts := strings.SplitN(arg, "@", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return fmt.Errorf("invalid user@host format")
|
||||||
|
}
|
||||||
|
if vncUsername == "" {
|
||||||
|
vncUsername = parts[0]
|
||||||
|
}
|
||||||
|
vncHost = parts[1]
|
||||||
|
if vncMode == "attach" {
|
||||||
|
vncMode = "session"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vncHost = arg
|
||||||
|
}
|
||||||
|
|
||||||
|
if vncMode == "session" && vncUsername == "" {
|
||||||
|
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||||
|
vncUsername = sudoUser
|
||||||
|
} else if currentUser, err := user.Current(); err == nil {
|
||||||
|
vncUsername = currentUser.Username
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runVNC(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||||
|
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = grpcConn.Close() }()
|
||||||
|
|
||||||
|
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||||
|
|
||||||
|
if vncMode == "session" {
|
||||||
|
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
|
||||||
|
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
|
||||||
|
var jwtToken string
|
||||||
|
hint := profilemanager.GetLoginHint()
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !vncNoBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
|
||||||
|
} else {
|
||||||
|
jwtToken = token
|
||||||
|
log.Debug("JWT authentication successful")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the VNC server on the standard port (5900). The peer's firewall
|
||||||
|
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
|
||||||
|
vncAddr := net.JoinHostPort(vncHost, "5900")
|
||||||
|
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
|
||||||
|
}
|
||||||
|
defer vncConn.Close()
|
||||||
|
|
||||||
|
// Send session header with mode, username, and JWT.
|
||||||
|
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
|
||||||
|
return fmt.Errorf("send VNC header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("VNC connected to %s\n", vncHost)
|
||||||
|
|
||||||
|
if vncListen != "" {
|
||||||
|
return runVNCLocalProxy(ctx, cmd, vncConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No --listen flag: inform the user they need to use --listen for external viewers.
|
||||||
|
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
|
||||||
|
cmd.Printf("Press Ctrl+C to disconnect.\n")
|
||||||
|
<-ctx.Done()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const vncDialTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
// sendVNCHeader writes the NetBird VNC session header.
|
||||||
|
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
|
||||||
|
var modeByte byte
|
||||||
|
if mode == "session" {
|
||||||
|
modeByte = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
usernameBytes := []byte(username)
|
||||||
|
jwtBytes := []byte(jwt)
|
||||||
|
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
|
||||||
|
hdr[0] = modeByte
|
||||||
|
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
|
||||||
|
off := 3
|
||||||
|
copy(hdr[off:], usernameBytes)
|
||||||
|
off += len(usernameBytes)
|
||||||
|
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
|
||||||
|
off += 2
|
||||||
|
copy(hdr[off:], jwtBytes)
|
||||||
|
|
||||||
|
_, err := conn.Write(hdr)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// runVNCLocalProxy listens on the given address and proxies incoming
|
||||||
|
// connections to the already-established VNC tunnel.
|
||||||
|
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
|
||||||
|
listener, err := net.Listen("tcp", vncListen)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen on %s: %w", vncListen, err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
|
||||||
|
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Accept a single viewer connection. VNC is single-session: the RFB
|
||||||
|
// handshake completes on vncConn for the first viewer, so subsequent
|
||||||
|
// viewers would get a mid-stream connection. The loop handles transient
|
||||||
|
// accept errors until a valid connection arrives.
|
||||||
|
for {
|
||||||
|
clientConn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
log.Debugf("accept VNC proxy client: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
|
||||||
|
|
||||||
|
// Bidirectional copy.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
io.Copy(vncConn, clientConn)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
io.Copy(clientConn, vncConn)
|
||||||
|
<-done
|
||||||
|
clientConn.Close()
|
||||||
|
|
||||||
|
cmd.Printf("VNC viewer disconnected\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
62
client/cmd/vnc_agent.go
Normal file
62
client/cmd/vnc_agent.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var vncAgentPort string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
|
||||||
|
rootCmd.AddCommand(vncAgentCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// vncAgentCmd runs a VNC server in the current user session, listening on
|
||||||
|
// localhost. It is spawned by the NetBird service (Session 0) via
|
||||||
|
// CreateProcessAsUser into the interactive console session.
|
||||||
|
var vncAgentCmd = &cobra.Command{
|
||||||
|
Use: "vnc-agent",
|
||||||
|
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Agent's stderr is piped to the service which relogs it.
|
||||||
|
// Use JSON format with caller info for structured parsing.
|
||||||
|
log.SetReportCaller(true)
|
||||||
|
log.SetFormatter(&log.JSONFormatter{})
|
||||||
|
log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
sessionID := vncserver.GetCurrentSessionID()
|
||||||
|
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
|
||||||
|
|
||||||
|
capturer := vncserver.NewDesktopCapturer()
|
||||||
|
injector := vncserver.NewWindowsInputInjector()
|
||||||
|
srv := vncserver.New(capturer, injector, "")
|
||||||
|
// Auth is handled by the service. The agent verifies a token on each
|
||||||
|
// connection to ensure only the service process can connect.
|
||||||
|
// The token is passed via environment variable to avoid exposing it
|
||||||
|
// in the process command line (visible via tasklist/wmic).
|
||||||
|
srv.SetDisableAuth(true)
|
||||||
|
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
|
||||||
|
|
||||||
|
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||||
|
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
<-cmd.Context().Done()
|
||||||
|
return srv.Stop()
|
||||||
|
},
|
||||||
|
}
|
||||||
16
client/cmd/vnc_flags.go
Normal file
16
client/cmd/vnc_flags.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
const (
|
||||||
|
serverVNCAllowedFlag = "allow-server-vnc"
|
||||||
|
disableVNCAuthFlag = "disable-vnc-auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serverVNCAllowed bool
|
||||||
|
disableVNCAuth bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
|
||||||
|
}
|
||||||
229
client/cmd/vnc_recordings.go
Normal file
229
client/cmd/vnc_recordings.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var vncRecDir string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
vncRecPlayCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
|
||||||
|
vncRecListCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
|
||||||
|
vncRecCmd.AddCommand(vncRecListCmd)
|
||||||
|
vncRecCmd.AddCommand(vncRecPlayCmd)
|
||||||
|
vncRecCmd.AddCommand(vncRecKeygenCmd)
|
||||||
|
vncCmd.AddCommand(vncRecCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecCmd = &cobra.Command{
|
||||||
|
Use: "rec",
|
||||||
|
Short: "Manage VNC session recordings",
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecKeygenCmd = &cobra.Command{
|
||||||
|
Use: "keygen",
|
||||||
|
Short: "Generate an X25519 keypair for recording encryption",
|
||||||
|
Long: `Generates an X25519 keypair. Put the public key in management settings
|
||||||
|
(Session Recording > Encryption Key). Keep the private key safe for decrypting recordings.`,
|
||||||
|
RunE: vncRecKeygenFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Short: "List VNC session recordings",
|
||||||
|
RunE: vncRecListFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecPlayCmd = &cobra.Command{
|
||||||
|
Use: "play <file-or-name>",
|
||||||
|
Short: "Open a VNC recording in the browser",
|
||||||
|
Long: `Opens a browser-based player with playback controls:
|
||||||
|
play/pause, seek, speed (0.25x to 8x), keyboard shortcuts.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird vnc rec play last
|
||||||
|
netbird vnc rec play 20260416-104433_vnc.rec`,
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: vncRecPlayFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func vncRecListFn(cmd *cobra.Command, _ []string) error {
|
||||||
|
dir, err := resolveVNCRecDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read recording dir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "FILE\tSIZE\tDIMENSIONS\tUSER\tREMOTE\tMODE\tDATE")
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filePath := filepath.Join(dir, entry.Name())
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := vncserver.ReadRecordingHeader(filePath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t?\t?\t?\t?\t?\n", entry.Name(), vncFormatSize(info.Size()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%dx%d\t%s\t%s\t%s\t%s\n",
|
||||||
|
entry.Name(),
|
||||||
|
vncFormatSize(info.Size()),
|
||||||
|
header.Width, header.Height,
|
||||||
|
header.Meta.User,
|
||||||
|
header.Meta.RemoteAddr,
|
||||||
|
header.Meta.Mode,
|
||||||
|
header.StartTime.Format("2006-01-02 15:04:05"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncRecPlayFn(cmd *cobra.Command, args []string) error {
|
||||||
|
filePath, err := resolveVNCRecFile(args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := vncserver.ReadRecordingHeader(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read recording: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Recording: %s (%dx%d)\n", filepath.Base(filePath), header.Width, header.Height)
|
||||||
|
|
||||||
|
url, err := vncserver.ServeWebPlayer(filePath, "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("start web player: %w", err)
|
||||||
|
}
|
||||||
|
cmd.Printf("Player: %s\n", url)
|
||||||
|
if err := util.OpenBrowser(url); err != nil {
|
||||||
|
cmd.Printf("Open %s in your browser\n", url)
|
||||||
|
}
|
||||||
|
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-sig
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func vncRecKeygenFn(cmd *cobra.Command, _ []string) error {
|
||||||
|
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privB64 := base64.StdEncoding.EncodeToString(priv.Bytes())
|
||||||
|
pubB64 := base64.StdEncoding.EncodeToString(priv.PublicKey().Bytes())
|
||||||
|
|
||||||
|
cmd.Printf("Private key (keep secret, for decrypting recordings):\n %s\n\n", privB64)
|
||||||
|
cmd.Printf("Public key (paste into management Settings > Session Recording > Encryption Key):\n %s\n", pubB64)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncFormatSize(size int64) string {
|
||||||
|
switch {
|
||||||
|
case size >= 1<<20:
|
||||||
|
return fmt.Sprintf("%.1fM", float64(size)/float64(1<<20))
|
||||||
|
case size >= 1<<10:
|
||||||
|
return fmt.Sprintf("%.1fK", float64(size)/float64(1<<10))
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%dB", size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveVNCRecDir() (string, error) {
|
||||||
|
if vncRecDir != "" {
|
||||||
|
return vncRecDir, nil
|
||||||
|
}
|
||||||
|
candidates := []string{
|
||||||
|
"/var/lib/netbird/recordings/vnc",
|
||||||
|
filepath.Join(os.Getenv("HOME"), ".netbird/recordings/vnc"),
|
||||||
|
}
|
||||||
|
for _, dir := range candidates {
|
||||||
|
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
|
||||||
|
return dir, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no VNC recording directory found; use --dir to specify")
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveVNCRecFile(arg string) (string, error) {
|
||||||
|
if strings.Contains(arg, "/") || strings.Contains(arg, string(os.PathSeparator)) {
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dir, err := resolveVNCRecDir()
|
||||||
|
if err != nil && arg != "last" {
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if arg == "last" {
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return findLatestRec(dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
full := filepath.Join(dir, arg)
|
||||||
|
if _, err := os.Stat(full); err == nil {
|
||||||
|
return full, nil
|
||||||
|
}
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findLatestRec(dir string) (string, error) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read dir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var latest string
|
||||||
|
var latestTime time.Time
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.ModTime().After(latestTime) {
|
||||||
|
latestTime = info.ModTime()
|
||||||
|
latest = filepath.Join(dir, entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if latest == "" {
|
||||||
|
return "", fmt.Errorf("no recordings found in %s", dir)
|
||||||
|
}
|
||||||
|
return latest, nil
|
||||||
|
}
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
|
||||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
|
||||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
|
||||||
// versions, which returns EPERM to any other process that tries to insert
|
|
||||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
|
||||||
// itself add the accept rules to its own chains by trusting the interface.
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
|
||||||
// should bypass firewalld filtering.
|
|
||||||
const TrustedZone = "trusted"
|
|
||||||
@@ -1,260 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
dbusDest = "org.fedoraproject.FirewallD1"
|
|
||||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
|
||||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
|
||||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
|
||||||
|
|
||||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
|
||||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
|
||||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
|
||||||
errNotEnabled = "NOT_ENABLED"
|
|
||||||
|
|
||||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
|
||||||
// A fresh context is created for each call so a slow DBus probe can't
|
|
||||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
|
||||||
callTimeout = 3 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
|
||||||
|
|
||||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
|
||||||
// Info level only for the first successful add per process; repeat adds
|
|
||||||
// from other init paths are quieter.
|
|
||||||
trustLogOnce sync.Once
|
|
||||||
|
|
||||||
parentCtxMu sync.RWMutex
|
|
||||||
parentCtx context.Context = context.Background()
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetParentContext installs a parent context whose cancellation aborts any
|
|
||||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
|
||||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
|
||||||
// during engine shutdown when the engine context is already cancelled.
|
|
||||||
func SetParentContext(ctx context.Context) {
|
|
||||||
parentCtxMu.Lock()
|
|
||||||
parentCtx = ctx
|
|
||||||
parentCtxMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func getParentContext() context.Context {
|
|
||||||
parentCtxMu.RLock()
|
|
||||||
defer parentCtxMu.RUnlock()
|
|
||||||
return parentCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
|
||||||
// running. It is idempotent and best-effort: errors are returned so callers
|
|
||||||
// can log, but a non-running firewalld is not an error. Only the first
|
|
||||||
// successful call per process logs at Info. Respects the parent context set
|
|
||||||
// via SetParentContext so startup-time cancellation unblocks it.
|
|
||||||
func TrustInterface(iface string) error {
|
|
||||||
parent := getParentContext()
|
|
||||||
if !isRunning(parent) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := addTrusted(parent, iface); err != nil {
|
|
||||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
|
||||||
}
|
|
||||||
trustLogOnce.Do(func() {
|
|
||||||
log.Infof("added %s to firewalld trusted zone", iface)
|
|
||||||
})
|
|
||||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
|
||||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
|
||||||
// during shutdown after the engine context has been cancelled.
|
|
||||||
func UntrustInterface(iface string) error {
|
|
||||||
if !isRunning(context.Background()) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
|
||||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
|
||||||
return context.WithTimeout(parent, callTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunning(parent context.Context) bool {
|
|
||||||
ctx, cancel := newCallContext(parent)
|
|
||||||
ok, err := isRunningDBus(ctx)
|
|
||||||
cancel()
|
|
||||||
if err == nil {
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
ctx, cancel = newCallContext(parent)
|
|
||||||
defer cancel()
|
|
||||||
return isRunningCLI(ctx)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func addTrusted(parent context.Context, iface string) error {
|
|
||||||
ctx, cancel := newCallContext(parent)
|
|
||||||
err := addDBus(ctx, iface)
|
|
||||||
cancel()
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !errors.Is(err, errDBusUnavailable) {
|
|
||||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
|
||||||
}
|
|
||||||
ctx, cancel = newCallContext(parent)
|
|
||||||
defer cancel()
|
|
||||||
return addCLI(ctx, iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeTrusted(parent context.Context, iface string) error {
|
|
||||||
ctx, cancel := newCallContext(parent)
|
|
||||||
err := removeDBus(ctx, iface)
|
|
||||||
cancel()
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if !errors.Is(err, errDBusUnavailable) {
|
|
||||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
|
||||||
}
|
|
||||||
ctx, cancel = newCallContext(parent)
|
|
||||||
defer cancel()
|
|
||||||
return removeCLI(ctx, iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
|
||||||
conn, err := dbus.SystemBus()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
|
||||||
}
|
|
||||||
obj := conn.Object(dbusDest, dbusPath)
|
|
||||||
|
|
||||||
var zone string
|
|
||||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
|
||||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunningCLI(ctx context.Context) bool {
|
|
||||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDBus(ctx context.Context, iface string) error {
|
|
||||||
conn, err := dbus.SystemBus()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
|
||||||
}
|
|
||||||
obj := conn.Object(dbusDest, dbusPath)
|
|
||||||
|
|
||||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
|
||||||
if call.Err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
|
||||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
|
||||||
if move.Err != nil {
|
|
||||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeDBus(ctx context.Context, iface string) error {
|
|
||||||
conn, err := dbus.SystemBus()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
|
||||||
}
|
|
||||||
obj := conn.Object(dbusDest, dbusPath)
|
|
||||||
|
|
||||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
|
||||||
if call.Err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addCLI(ctx context.Context, iface string) error {
|
|
||||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
|
||||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --change-interface (no --permanent) binds the interface for the
|
|
||||||
// current runtime only; we do not want membership to persist across
|
|
||||||
// reboots because netbird re-asserts it on every startup.
|
|
||||||
out, err := exec.CommandContext(ctx,
|
|
||||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
|
||||||
).CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeCLI(ctx context.Context, iface string) error {
|
|
||||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
|
||||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := exec.CommandContext(ctx,
|
|
||||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
|
||||||
).CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
msg := strings.TrimSpace(string(out))
|
|
||||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func dbusErrContains(err error, code string) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var de dbus.Error
|
|
||||||
if errors.As(err, &de) {
|
|
||||||
for _, b := range de.Body {
|
|
||||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Contains(err.Error(), code)
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDBusErrContains(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
code string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"nil error", nil, errZoneAlreadySet, false},
|
|
||||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
|
||||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
|
||||||
{
|
|
||||||
"dbus.Error body match",
|
|
||||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
|
||||||
errZoneAlreadySet,
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dbus.Error body miss",
|
|
||||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
|
||||||
errAlreadyEnabled,
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dbus.Error non-string body falls back to Error()",
|
|
||||||
dbus.Error{Name: "x", Body: []any{123}},
|
|
||||||
"x",
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := dbusErrContains(tc.err, tc.code)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package firewalld
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
|
||||||
// runs on Linux.
|
|
||||||
func SetParentContext(context.Context) {
|
|
||||||
// intentionally empty: firewalld is a Linux-only daemon
|
|
||||||
}
|
|
||||||
|
|
||||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
|
||||||
// runs on Linux.
|
|
||||||
func TrustInterface(string) error {
|
|
||||||
// intentionally empty: firewalld is a Linux-only daemon
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
|
||||||
// runs on Linux.
|
|
||||||
func UntrustInterface(string) error {
|
|
||||||
// intentionally empty: firewalld is a Linux-only daemon
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
|
||||||
// interface in firewalld's trusted zone without a corresponding Close.
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
|
||||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
|
||||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// attempt to delete state only if all other operations succeeded
|
// attempt to delete state only if all other operations succeeded
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
@@ -41,8 +40,6 @@ const (
|
|||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
firewalldTableName = "firewalld"
|
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
userDataAcceptInputRule = "inputaccept"
|
userDataAcceptInputRule = "inputaccept"
|
||||||
@@ -136,10 +133,6 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeNatPreroutingRules(); err != nil {
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
}
|
}
|
||||||
@@ -287,10 +280,6 @@ func (r *router) createContainers() error {
|
|||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to refresh rules: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1330,13 +1319,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
|
||||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
|
||||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
|
||||||
if chain.Table.Name == firewalldTableName {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip all iptables-managed tables in the ip family
|
// Skip all iptables-managed tables in the ip family
|
||||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,9 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AllowNetbird()
|
return m.nativeFirewall.AllowNetbird()
|
||||||
}
|
}
|
||||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
|
|||||||
@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
NameFunc func() string
|
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) Name() string {
|
|
||||||
if i.NameFunc == nil {
|
|
||||||
return "wgtest"
|
|
||||||
}
|
|
||||||
return i.NameFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
if i.GetWGDeviceFunc == nil {
|
if i.GetWGDeviceFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
|||||||
ipv6Count++
|
ipv6Count++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||||
// routing-correctness checks above are the real assertions; the counts
|
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||||
// are a sanity bound to catch a totally silent path.
|
|
||||||
minDelivered := packetsPerFamily * 80 / 100
|
|
||||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
|
||||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
|||||||
@@ -201,18 +201,7 @@ Pop $0
|
|||||||
|
|
||||||
Function .onInit
|
Function .onInit
|
||||||
StrCpy $INSTDIR "${INSTALL_DIR}"
|
StrCpy $INSTDIR "${INSTALL_DIR}"
|
||||||
; Default autostart to enabled so silent installs (/S) match the interactive default
|
|
||||||
StrCpy $AutostartEnabled "1"
|
|
||||||
|
|
||||||
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
|
|
||||||
; in the 32-bit view. Fall back to it so upgrades still find them.
|
|
||||||
SetRegView 64
|
|
||||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
||||||
${If} $R0 == ""
|
|
||||||
SetRegView 32
|
|
||||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
|
||||||
SetRegView 64
|
|
||||||
${EndIf}
|
|
||||||
${If} $R0 != ""
|
${If} $R0 != ""
|
||||||
# if silent install jump to uninstall step
|
# if silent install jump to uninstall step
|
||||||
IfSilent uninstall
|
IfSilent uninstall
|
||||||
@@ -225,10 +214,6 @@ ${If} $R0 != ""
|
|||||||
|
|
||||||
${EndIf}
|
${EndIf}
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
Function un.onInit
|
|
||||||
SetRegView 64
|
|
||||||
FunctionEnd
|
|
||||||
######################################################################
|
######################################################################
|
||||||
Section -MainProgram
|
Section -MainProgram
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
@@ -243,7 +228,6 @@ Section -MainProgram
|
|||||||
!else
|
!else
|
||||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||||
!endif
|
!endif
|
||||||
File "..\\client\\ui\\assets\\netbird.png"
|
|
||||||
SectionEnd
|
SectionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
|||||||
; Create autostart registry entry based on checkbox
|
; Create autostart registry entry based on checkbox
|
||||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||||
${If} $AutostartEnabled == "1"
|
${If} $AutostartEnabled == "1"
|
||||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||||
${Else}
|
${Else}
|
||||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
|
||||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
DetailPrint "Autostart not enabled by user"
|
DetailPrint "Autostart not enabled by user"
|
||||||
${EndIf}
|
${EndIf}
|
||||||
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
|||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
DetailPrint "Removing autostart registry entry if exists..."
|
DetailPrint "Removing autostart registry entry if exists..."
|
||||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
|
||||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
; Handle data deletion based on checkbox
|
; Handle data deletion based on checkbox
|
||||||
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
|
|||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||||
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
|
|
||||||
|
|
||||||
DetailPrint "Removing application directory from PATH..."
|
DetailPrint "Removing application directory from PATH..."
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
|
|||||||
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|||||||
a.config.RosenpassEnabled,
|
a.config.RosenpassEnabled,
|
||||||
a.config.RosenpassPermissive,
|
a.config.RosenpassPermissive,
|
||||||
a.config.ServerSSHAllowed,
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.ServerVNCAllowed,
|
||||||
a.config.DisableClientRoutes,
|
a.config.DisableClientRoutes,
|
||||||
a.config.DisableServerRoutes,
|
a.config.DisableServerRoutes,
|
||||||
a.config.DisableDNS,
|
a.config.DisableDNS,
|
||||||
@@ -327,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|||||||
a.config.EnableSSHLocalPortForwarding,
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
a.config.EnableSSHRemotePortForwarding,
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
a.config.DisableSSHAuth,
|
a.config.DisableSSHAuth,
|
||||||
|
a.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.statusRecorder.MarkSignalConnected()
|
c.statusRecorder.MarkSignalConnected()
|
||||||
|
|
||||||
relayURLs, token := parseRelayInfo(loginResp)
|
relayURLs, token := parseRelayInfo(loginResp)
|
||||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
|
||||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
|
||||||
relayURLs = override
|
|
||||||
}
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||||
@@ -550,11 +546,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
RosenpassEnabled: config.RosenpassEnabled,
|
RosenpassEnabled: config.RosenpassEnabled,
|
||||||
RosenpassPermissive: config.RosenpassPermissive,
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
|
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||||
EnableSSHRoot: config.EnableSSHRoot,
|
EnableSSHRoot: config.EnableSSHRoot,
|
||||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||||
DisableSSHAuth: config.DisableSSHAuth,
|
DisableSSHAuth: config.DisableSSHAuth,
|
||||||
|
DisableVNCAuth: config.DisableVNCAuth,
|
||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
@@ -631,6 +629,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.RosenpassEnabled,
|
config.RosenpassEnabled,
|
||||||
config.RosenpassPermissive,
|
config.RosenpassPermissive,
|
||||||
config.ServerSSHAllowed,
|
config.ServerSSHAllowed,
|
||||||
|
config.ServerVNCAllowed,
|
||||||
config.DisableClientRoutes,
|
config.DisableClientRoutes,
|
||||||
config.DisableServerRoutes,
|
config.DisableServerRoutes,
|
||||||
config.DisableDNS,
|
config.DisableDNS,
|
||||||
@@ -643,6 +642,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHLocalPortForwarding,
|
config.EnableSSHLocalPortForwarding,
|
||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
config.DisableSSHAuth,
|
config.DisableSSHAuth,
|
||||||
|
config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,10 @@ package debug
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Skip("Skipping upload test on docker ci")
|
t.Skip("Skipping upload test on docker ci")
|
||||||
}
|
}
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
addr := reserveLoopbackPort(t)
|
testURL := "http://localhost:8080"
|
||||||
testURL := "http://" + addr
|
|
||||||
t.Setenv("SERVER_URL", testURL)
|
t.Setenv("SERVER_URL", testURL)
|
||||||
t.Setenv("SERVER_ADDRESS", addr)
|
|
||||||
t.Setenv("STORE_DIR", testDir)
|
t.Setenv("STORE_DIR", testDir)
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Errorf("Failed to stop server: %v", err)
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
waitForServer(t, addr)
|
|
||||||
|
|
||||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, fileContent, createdFileContent)
|
require.Equal(t, fileContent, createdFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
|
||||||
// address, then releases it so the server under test can rebind. The close/
|
|
||||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
|
||||||
// it's essentially never contended in practice.
|
|
||||||
func reserveLoopbackPort(t *testing.T) string {
|
|
||||||
t.Helper()
|
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
addr := l.Addr().String()
|
|
||||||
require.NoError(t, l.Close())
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForServer(t *testing.T, addr string) {
|
|
||||||
t.Helper()
|
|
||||||
deadline := time.Now().Add(5 * time.Second)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
|
||||||
if err == nil {
|
|
||||||
_ = c.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
}
|
|
||||||
t.Fatalf("server did not start listening on %s in time", addr)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
c.dispatch(w, r, math.MaxInt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
|
||||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
|
||||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
|
|||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
if entry.Priority > maxPriority {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !c.isHandlerMatch(qname, entry) {
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
cw.response.Len(), meta, time.Since(startTime))
|
cw.response.Len(), meta, time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
|
||||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
|
||||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
|
||||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
|
||||||
// (bounded by the invoked handler's internal timeout).
|
|
||||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
|
||||||
if len(r.Question) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty question")
|
|
||||||
}
|
|
||||||
|
|
||||||
base := &internalResponseWriter{}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
c.dispatch(base, r, maxPriority)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctx.Done():
|
|
||||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
|
||||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
|
||||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
|
||||||
}
|
|
||||||
return base.response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
|
||||||
// priority ≤ maxPriority.
|
|
||||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, h := range c.handlers {
|
|
||||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
switch {
|
switch {
|
||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
|
||||||
type internalResponseWriter struct {
|
|
||||||
response *dns.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
|
||||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
|
|
||||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
|
||||||
// still surface their answer to ResolveInternal.
|
|
||||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(p); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
w.response = msg
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *internalResponseWriter) Close() error { return nil }
|
|
||||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
|
||||||
|
|
||||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
|
||||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
|
||||||
// no-op: in-process queries carry no TSIG state.
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack is part of dns.ResponseWriter.
|
|
||||||
func (w *internalResponseWriter) Hijack() {
|
|
||||||
// no-op: in-process queries have no underlying connection to hand off.
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,15 +1,11 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
|
||||||
// which handler ResolveInternal dispatches to.
|
|
||||||
type answeringHandler struct {
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
resp.Answer = []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP(h.ip).To4(),
|
|
||||||
}}
|
|
||||||
_ = w.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *answeringHandler) String() string { return h.name }
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
|
|
||||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
|
||||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
|
||||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp)
|
|
||||||
assert.Equal(t, 1, len(resp.Answer))
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
|
||||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
|
||||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
|
||||||
type rawWriteHandler struct {
|
|
||||||
ip string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
resp.Answer = []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP(h.ip).To4(),
|
|
||||||
}}
|
|
||||||
packed, err := resp.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = w.Write(packed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Len(t, resp.Answer, 1)
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
|
||||||
type hangingHandler struct {
|
|
||||||
block chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
<-h.block
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
_ = w.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
h := &hangingHandler{block: make(chan struct{})}
|
|
||||||
defer close(h.block)
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
||||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
|
||||||
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
|
||||||
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
|
||||||
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
|
||||||
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
|
||||||
|
|
||||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, reason, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
log.Infof("System DNS manager discovered: %s", osManager)
|
||||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
func getOSDNSManagerType() (osManagerType, error) {
|
||||||
resolved := isSystemdResolvedRunning()
|
|
||||||
nss := isLibnssResolveUsed()
|
|
||||||
stub := checkStub()
|
|
||||||
|
|
||||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
|
||||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
|
||||||
// that go through nss-resolve, and in foreign mode they can loop back
|
|
||||||
// through resolved as an upstream.
|
|
||||||
if resolved && (nss || stub) {
|
|
||||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", err
|
|
||||||
}
|
|
||||||
if reason != "" {
|
|
||||||
return mgr, reason, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
|
||||||
if len(rejected) > 0 {
|
|
||||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
|
||||||
}
|
|
||||||
return fileManager, fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
|
||||||
// matching manager. If reason is empty the caller should pick file mode and
|
|
||||||
// use rejected for diagnostics.
|
|
||||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if cerr := file.Close(); cerr != nil {
|
if err := file.Close(); err != nil {
|
||||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rejected []string
|
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text[0] != '#' {
|
if text[0] != '#' {
|
||||||
break
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||||
return mgr, reason, nil, nil
|
return netbirdManager, nil
|
||||||
} else if rej != "" {
|
}
|
||||||
rejected = append(rejected, rej)
|
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
return networkManager, nil
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||||
|
if checkStub() {
|
||||||
|
return systemdManager, nil
|
||||||
|
} else {
|
||||||
|
return fileManager, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "resolvconf") {
|
||||||
|
if isSystemdResolveConfMode() {
|
||||||
|
return systemdManager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolvConfManager, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
return 0, fmt.Errorf("scan: %w", err)
|
||||||
}
|
}
|
||||||
return 0, "", rejected, nil
|
|
||||||
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
|
||||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
|
||||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
|
||||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "NetworkManager") {
|
|
||||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
|
||||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
|
||||||
}
|
|
||||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "resolvconf") {
|
|
||||||
if isSystemdResolveConfMode() {
|
|
||||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
|
||||||
}
|
|
||||||
return resolvConfManager, "resolvconf header detected", ""
|
|
||||||
}
|
|
||||||
return 0, "", ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
|
||||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
|
||||||
// into file mode while resolved is active.
|
|
||||||
func checkStub() bool {
|
func checkStub() bool {
|
||||||
rConf, err := parseDefaultResolvConf()
|
rConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
log.Warnf("failed to parse resolv conf: %s", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,36 +139,3 @@ func checkStub() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
|
||||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
|
||||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
|
||||||
func isLibnssResolveUsed() bool {
|
|
||||||
bs, err := os.ReadFile(nsswitchConfPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return parseNsswitchResolveAhead(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseNsswitchResolveAhead(data []byte) bool {
|
|
||||||
for _, line := range strings.Split(string(data), "\n") {
|
|
||||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
|
||||||
line = line[:i]
|
|
||||||
}
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, module := range fields[1:] {
|
|
||||||
switch module {
|
|
||||||
case "dns":
|
|
||||||
return false
|
|
||||||
case "resolve":
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
//go:build (linux && !android) || freebsd
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "resolve before dns with action token",
|
|
||||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "dns before resolve",
|
|
||||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "debian default with only dns",
|
|
||||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "neither resolve nor dns",
|
|
||||||
in: "hosts: files myhostname\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no hosts line",
|
|
||||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
in: "",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "comments and blank lines ignored",
|
|
||||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trailing inline comment",
|
|
||||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hosts token must be the first field",
|
|
||||||
in: " hosts: resolve dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "other db line mentioning resolve is ignored",
|
|
||||||
in: "networks: resolve\nhosts: dns\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only resolve, no dns",
|
|
||||||
in: "hosts: files resolve\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
|
||||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,83 +2,40 @@ package mgmt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sync/singleflight"
|
|
||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const dnsTimeout = 5 * time.Second
|
||||||
dnsTimeout = 5 * time.Second
|
|
||||||
defaultTTL = 300 * time.Second
|
|
||||||
refreshBackoff = 30 * time.Second
|
|
||||||
|
|
||||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
// Resolver caches critical NetBird infrastructure domains
|
||||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
|
||||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
|
||||||
// system resolver.
|
|
||||||
type ChainResolver interface {
|
|
||||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
|
||||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
|
||||||
// records and cachedAt are set at construction and treated as immutable;
|
|
||||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
|
||||||
// Resolver.mutex.
|
|
||||||
type cachedRecord struct {
|
|
||||||
records []dns.RR
|
|
||||||
cachedAt time.Time
|
|
||||||
lastFailedRefresh time.Time
|
|
||||||
consecFailures int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains.
|
|
||||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question]*cachedRecord
|
records map[dns.Question][]dns.RR
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
chain ChainResolver
|
type ipsResponse struct {
|
||||||
chainMaxPriority int
|
ips []netip.Addr
|
||||||
refreshGroup singleflight.Group
|
err error
|
||||||
|
|
||||||
// refreshing tracks questions whose refresh is running via the OS
|
|
||||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
|
||||||
// the OS resolver routed the recursive query back to us (loop). Only
|
|
||||||
// the OS path arms this so chain-path refreshes don't produce false
|
|
||||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
|
||||||
// throttle the warning log.
|
|
||||||
refreshing map[dns.Question]*atomic.Bool
|
|
||||||
|
|
||||||
cacheTTL time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question]*cachedRecord),
|
records: make(map[dns.Question][]dns.RR),
|
||||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
|
||||||
cacheTTL: resolveCacheTTL(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
|
|||||||
return "MgmtCacheResolver"
|
return "MgmtCacheResolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
// ServeDNS implements dns.Handler interface.
|
||||||
// maxPriority caps which handlers may answer refresh queries (typically
|
|
||||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
|
||||||
// mgmt/route/local handlers are skipped).
|
|
||||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.chain = chain
|
|
||||||
m.chainMaxPriority = maxPriority
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
|
||||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
|
||||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
cached, found := m.records[question]
|
records, found := m.records[question]
|
||||||
inflight := m.refreshing[question]
|
|
||||||
var shouldRefresh bool
|
|
||||||
if found {
|
|
||||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
|
||||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
|
||||||
shouldRefresh = stale && !inBackoff
|
|
||||||
}
|
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
|
||||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
|
||||||
question.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
|
||||||
// this question; singleflight would dedup anyway but skipping avoids
|
|
||||||
// a parked goroutine per stale hit under bursty load.
|
|
||||||
if shouldRefresh && inflight == nil {
|
|
||||||
m.scheduleRefresh(question, cached)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetReply(r)
|
resp.SetReply(r)
|
||||||
resp.Authoritative = false
|
resp.Authoritative = false
|
||||||
resp.RecursionAvailable = true
|
resp.RecursionAvailable = true
|
||||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
|
||||||
|
resp.Answer = append(resp.Answer, records...)
|
||||||
|
|
||||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||||
|
|
||||||
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
// AddDomain manually adds a domain to cache by resolving it.
|
||||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
|
||||||
// entry for that qtype.
|
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||||
|
if err != nil {
|
||||||
if errA != nil && errAAAA != nil {
|
return err
|
||||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
var aRecords, aaaaRecords []dns.RR
|
||||||
if err := errors.Join(errA, errAAAA); err != nil {
|
for _, ip := range ips {
|
||||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
if ip.Is4() {
|
||||||
|
rr := &dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dnsName,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
A: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aRecords = append(aRecords, rr)
|
||||||
|
} else if ip.Is6() {
|
||||||
|
rr := &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dnsName,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
AAAA: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aaaaRecords = append(aaaaRecords, rr)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
if len(aRecords) > 0 {
|
||||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
aQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aQuestion] = aRecords
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
if len(aaaaRecords) > 0 {
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aaaaQuestion] = aaaaRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||||
// untouched on error. Caller holds m.mutex.
|
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
resultChan := make(chan *ipsResponse, 1)
|
||||||
switch {
|
|
||||||
case len(records) > 0:
|
|
||||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
|
||||||
case err == nil:
|
|
||||||
delete(m.records, q)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
go func() {
|
||||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
resultChan <- &ipsResponse{
|
||||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
err: err,
|
||||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
ips: ips,
|
||||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
|
||||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
|
||||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
|
||||||
return nil, m.refreshQuestion(question, expected)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
|
||||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
|
||||||
// a resolver loop by spotting a query for this same question arriving on us.
|
|
||||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
|
||||||
// if m.records[question] still points at it.
|
|
||||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
|
||||||
if err != nil {
|
|
||||||
m.markRefreshFailed(question, expected)
|
|
||||||
return fmt.Errorf("parse domain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
records, err := m.lookupRecords(ctx, d, question)
|
|
||||||
if err != nil {
|
|
||||||
fails := m.markRefreshFailed(question, expected)
|
|
||||||
logf := log.Warnf
|
|
||||||
if fails == 0 || fails > 1 {
|
|
||||||
logf = log.Debugf
|
|
||||||
}
|
}
|
||||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
}()
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
|
||||||
return err
|
var resp *ipsResponse
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||||
|
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||||
|
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp = <-resultChan:
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
if resp.err != nil {
|
||||||
if len(records) == 0 {
|
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||||
m.mutex.Lock()
|
|
||||||
if m.records[question] == expected {
|
|
||||||
delete(m.records, question)
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return resp.ips, nil
|
||||||
now := time.Now()
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.records[question] != expected {
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.refreshing[question] = &atomic.Bool{}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
delete(m.refreshing, question)
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
|
||||||
// count so callers can downgrade subsequent failure logs to debug.
|
|
||||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
c, ok := m.records[question]
|
|
||||||
if !ok || c != expected {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
c.lastFailedRefresh = time.Now()
|
|
||||||
c.consecFailures++
|
|
||||||
return c.consecFailures
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
|
||||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
|
||||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
chain := m.chain
|
|
||||||
maxPriority := m.chainMaxPriority
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
||||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
|
||||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
|
||||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
|
||||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
|
||||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
|
||||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
|
||||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
|
||||||
// spot the OS resolver routing the recursive query back to us.
|
|
||||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
chain := m.chain
|
|
||||||
maxPriority := m.chainMaxPriority
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
||||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: drop once every supported OS registers a fallback resolver.
|
|
||||||
m.markRefreshing(q)
|
|
||||||
defer m.clearRefreshing(q)
|
|
||||||
|
|
||||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
|
||||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
|
||||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
|
||||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
||||||
msg := &dns.Msg{}
|
|
||||||
msg.SetQuestion(dnsName, qtype)
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
|
||||||
}
|
|
||||||
if resp == nil {
|
|
||||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
|
||||||
}
|
|
||||||
if resp.Rcode != dns.RcodeSuccess {
|
|
||||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
|
||||||
}
|
|
||||||
|
|
||||||
ttl := uint32(m.cacheTTL.Seconds())
|
|
||||||
owners := cnameOwners(dnsName, resp.Answer)
|
|
||||||
var filtered []dns.RR
|
|
||||||
for _, rr := range resp.Answer {
|
|
||||||
h := rr.Header()
|
|
||||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
|
||||||
filtered = append(filtered, cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filtered, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
|
||||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
|
||||||
// returns (nil, nil).
|
|
||||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
||||||
network := resutil.NetworkForQtype(qtype)
|
|
||||||
if network == "" {
|
|
||||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
||||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
|
||||||
if result.Rcode == dns.RcodeSuccess {
|
|
||||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Err != nil {
|
|
||||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
|
||||||
}
|
|
||||||
|
|
||||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
|
||||||
// so downstream resolvers don't cache an answer for longer than we will.
|
|
||||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
|
||||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
|
||||||
if remaining <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return uint32((remaining + time.Second - 1) / time.Second)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
aQuestion := dns.Question{
|
||||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
Name: dnsName,
|
||||||
delete(m.records, qA)
|
Qtype: dns.TypeA,
|
||||||
delete(m.records, qAAAA)
|
Qclass: dns.ClassINET,
|
||||||
delete(m.refreshing, qA)
|
}
|
||||||
delete(m.refreshing, qAAAA)
|
delete(m.records, aQuestion)
|
||||||
|
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
delete(m.records, aaaaQuestion)
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
|
||||||
// A/AAAA records return nil.
|
|
||||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
|
||||||
switch r := rr.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
cp := *r
|
|
||||||
cp.Hdr.Name = owner
|
|
||||||
cp.Hdr.Ttl = ttl
|
|
||||||
cp.A = slices.Clone(r.A)
|
|
||||||
return &cp
|
|
||||||
case *dns.AAAA:
|
|
||||||
cp := *r
|
|
||||||
cp.Hdr.Name = owner
|
|
||||||
cp.Hdr.Ttl = ttl
|
|
||||||
cp.AAAA = slices.Clone(r.AAAA)
|
|
||||||
return &cp
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
|
||||||
// stamping ttl so the response shares no memory with the cached slice.
|
|
||||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
|
||||||
out := make([]dns.RR, 0, len(records))
|
|
||||||
for _, rr := range records {
|
|
||||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
|
||||||
out = append(out, cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
|
||||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
|
||||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
|
||||||
owners := map[string]bool{dnsName: true}
|
|
||||||
for {
|
|
||||||
added := false
|
|
||||||
for _, rr := range answer {
|
|
||||||
cname, ok := rr.(*dns.CNAME)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
|
||||||
if !owners[name] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
|
||||||
if !owners[target] {
|
|
||||||
owners[target] = true
|
|
||||||
added = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !added {
|
|
||||||
return owners
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
|
||||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
|
||||||
func resolveCacheTTL() time.Duration {
|
|
||||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
|
||||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return defaultTTL
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,408 +0,0 @@
|
|||||||
package mgmt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fakeChain struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
calls map[string]int
|
|
||||||
answers map[string][]dns.RR
|
|
||||||
err error
|
|
||||||
hasRoot bool
|
|
||||||
onLookup func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeChain() *fakeChain {
|
|
||||||
return &fakeChain{
|
|
||||||
calls: map[string]int{},
|
|
||||||
answers: map[string][]dns.RR{},
|
|
||||||
hasRoot: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.hasRoot
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
q := msg.Question[0]
|
|
||||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
|
||||||
f.calls[key]++
|
|
||||||
answers := f.answers[key]
|
|
||||||
err := f.err
|
|
||||||
onLookup := f.onLookup
|
|
||||||
f.mu.Unlock()
|
|
||||||
|
|
||||||
if onLookup != nil {
|
|
||||||
onLookup()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(msg)
|
|
||||||
resp.Answer = answers
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
key := name + "|" + dns.TypeToString[qtype]
|
|
||||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
|
||||||
switch qtype {
|
|
||||||
case dns.TypeA:
|
|
||||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
|
||||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
|
||||||
t.Helper()
|
|
||||||
deadline := time.Now().Add(d)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
if fn() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Millisecond)
|
|
||||||
}
|
|
||||||
t.Fatalf("condition not met within %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
|
||||||
t.Helper()
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.SetQuestion(name, dns.TypeA)
|
|
||||||
w := &test.MockResponseWriter{}
|
|
||||||
r.ServeDNS(w, msg)
|
|
||||||
return w.GetLastResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
|
||||||
t.Helper()
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
require.True(t, ok, "expected A record")
|
|
||||||
return a.A.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
|
||||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
|
||||||
// trigger a background refresh, the longer one must not. Proves that the
|
|
||||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
|
||||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
|
||||||
|
|
||||||
newRec := func() *cachedRecord {
|
|
||||||
return &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: cachedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
|
|
||||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
r.cacheTTL = 10 * time.Millisecond
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
r.records[q] = newRec()
|
|
||||||
|
|
||||||
resp := queryA(t, r, q.Name)
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
r.cacheTTL = time.Hour
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
r.records[q] = newRec()
|
|
||||||
|
|
||||||
resp := queryA(t, r, q.Name)
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(), // fresh
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
|
||||||
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
|
||||||
}
|
|
||||||
|
|
||||||
// First query: serves stale immediately.
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
|
||||||
})
|
|
||||||
|
|
||||||
// Next query should now return the refreshed IP.
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
|
|
||||||
var inflight atomic.Int32
|
|
||||||
var maxInflight atomic.Int32
|
|
||||||
chain.onLookup = func() {
|
|
||||||
cur := inflight.Add(1)
|
|
||||||
defer inflight.Add(-1)
|
|
||||||
for {
|
|
||||||
prev := maxInflight.Load()
|
|
||||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
|
||||||
}
|
|
||||||
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
waitFor(t, 2*time.Second, func() bool {
|
|
||||||
return inflight.Load() == 0
|
|
||||||
})
|
|
||||||
|
|
||||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
|
||||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
|
||||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.err = errors.New("boom")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
|
||||||
}
|
|
||||||
|
|
||||||
// First stale hit triggers a refresh attempt that fails.
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
|
||||||
})
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
c, ok := r.records[q]
|
|
||||||
return ok && !c.lastFailedRefresh.IsZero()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.hasRoot = false
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
// With hasRoot=false the chain must not be consulted. Use a short
|
|
||||||
// deadline so the OS fallback returns quickly without waiting on a
|
|
||||||
// real network call in CI.
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
|
||||||
|
|
||||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
|
||||||
"chain must not be used when no root handler is registered at the bound priority")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
|
||||||
// ServeDNS being invoked for a question while a refresh for that question
|
|
||||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
|
||||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate an inflight refresh.
|
|
||||||
r.markRefreshing(q)
|
|
||||||
defer r.clearRefreshing(q)
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
inflight := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
require.NotNil(t, inflight)
|
|
||||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
r.markRefreshing(q)
|
|
||||||
defer r.clearRefreshing(q)
|
|
||||||
|
|
||||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
|
||||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
|
||||||
for range 5 {
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
inflight := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
assert.True(t, inflight.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
_, ok := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
|
|||||||
assert.False(t, resolver.MatchSubdomains())
|
assert.False(t, resolver.MatchSubdomains())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveCacheTTL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
value string
|
|
||||||
want time.Duration
|
|
||||||
}{
|
|
||||||
{"unset falls back to default", "", defaultTTL},
|
|
||||||
{"valid duration", "45s", 45 * time.Second},
|
|
||||||
{"valid minutes", "2m", 2 * time.Minute},
|
|
||||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
|
||||||
{"zero falls back to default", "0s", defaultTTL},
|
|
||||||
{"negative falls back to default", "-5s", defaultTTL},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
|
||||||
got := resolveCacheTTL()
|
|
||||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
|
||||||
t.Setenv(envMgmtCacheTTL, "7s")
|
|
||||||
r := NewResolver()
|
|
||||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ResponseTTL(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cacheTTL time.Duration
|
|
||||||
cachedAt time.Time
|
|
||||||
wantMin uint32
|
|
||||||
wantMax uint32
|
|
||||||
}{
|
|
||||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
|
||||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
|
||||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
|
||||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
|
||||||
got := r.responseTTL(tc.cachedAt)
|
|
||||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
|
||||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -212,7 +212,6 @@ func newDefaultServer(
|
|||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
mgmtCacheResolver := mgmt.NewResolver()
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -118,11 +117,13 @@ type EngineConfig struct {
|
|||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
|
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
ServerVNCAllowed bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
DisableVNCAuth *bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
@@ -199,6 +200,7 @@ type Engine struct {
|
|||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
|
vncSrv vncServer
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -312,6 +314,10 @@ func (e *Engine) Stop() error {
|
|||||||
log.Warnf("failed to stop SSH server: %v", err)
|
log.Warnf("failed to stop SSH server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.stopVNCServer(); err != nil {
|
||||||
|
log.Warnf("failed to stop VNC server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
e.cleanupSSHConfig()
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
@@ -571,7 +577,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.connMgr.Start(e.ctx)
|
e.connMgr.Start(e.ctx)
|
||||||
|
|
||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
e.srWatcher.Start(peer.IsForceRelayed())
|
e.srWatcher.Start()
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
@@ -605,8 +611,6 @@ func (e *Engine) createFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
firewalld.SetParentContext(e.ctx)
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -944,12 +948,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
|||||||
return fmt.Errorf("update relay token: %w", err)
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
urls := update.Urls
|
e.relayManager.UpdateServerURLs(update.Urls)
|
||||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
|
||||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
|
||||||
urls = override
|
|
||||||
}
|
|
||||||
e.relayManager.UpdateServerURLs(urls)
|
|
||||||
|
|
||||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
@@ -1006,6 +1005,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
&e.config.ServerSSHAllowed,
|
&e.config.ServerSSHAllowed,
|
||||||
|
&e.config.ServerVNCAllowed,
|
||||||
e.config.DisableClientRoutes,
|
e.config.DisableClientRoutes,
|
||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
@@ -1018,6 +1018,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
|
e.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -1045,6 +1046,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
|
||||||
|
log.Warnf("failed handling VNC server setup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
@@ -1147,6 +1152,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
&e.config.ServerSSHAllowed,
|
&e.config.ServerSSHAllowed,
|
||||||
|
&e.config.ServerVNCAllowed,
|
||||||
e.config.DisableClientRoutes,
|
e.config.DisableClientRoutes,
|
||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
@@ -1159,6 +1165,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
|
e.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
@@ -1333,6 +1340,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
|
|
||||||
|
// VNC auth: use dedicated VNCAuth if present.
|
||||||
|
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
|
||||||
|
e.updateVNCServerAuth(vncAuth)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
@@ -1742,6 +1754,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
&e.config.ServerSSHAllowed,
|
&e.config.ServerSSHAllowed,
|
||||||
|
&e.config.ServerVNCAllowed,
|
||||||
e.config.DisableClientRoutes,
|
e.config.DisableClientRoutes,
|
||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
@@ -1754,6 +1767,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
|
e.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
|
|||||||
@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
309
client/internal/engine_vnc.go
Normal file
309
client/internal/engine_vnc.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const envVNCForceRecording = "NB_VNC_FORCE_RECORDING"
|
||||||
|
|
||||||
|
const (
|
||||||
|
vncExternalPort uint16 = 5900
|
||||||
|
vncInternalPort uint16 = 25900
|
||||||
|
)
|
||||||
|
|
||||||
|
type vncServer interface {
|
||||||
|
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) setupVNCPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||||
|
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||||
|
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||||
|
// sshConf provides the JWT identity provider config (shared with SSH).
|
||||||
|
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if !e.config.ServerVNCAllowed {
|
||||||
|
if e.vncSrv != nil {
|
||||||
|
log.Info("VNC server disabled, stopping")
|
||||||
|
}
|
||||||
|
return e.stopVNCServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
log.Info("VNC server disabled because inbound connections are blocked")
|
||||||
|
return e.stopVNCServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.vncSrv != nil {
|
||||||
|
// Update JWT config on existing server in case management sent new config.
|
||||||
|
e.updateVNCServerJWT(sshConf)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.startVNCServer(sshConf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wg interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
capturer, injector := newPlatformVNC()
|
||||||
|
if capturer == nil || injector == nil {
|
||||||
|
log.Debug("VNC server not supported on this platform")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
netbirdIP := e.wgInterface.Address().IP
|
||||||
|
|
||||||
|
srv := vncserver.New(capturer, injector, "")
|
||||||
|
if vncNeedsServiceMode() {
|
||||||
|
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||||
|
srv.SetServiceMode(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure VNC authentication.
|
||||||
|
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||||
|
log.Info("VNC: authentication disabled by config")
|
||||||
|
srv.SetDisableAuth(true)
|
||||||
|
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||||
|
audiences := protoJWT.GetAudiences()
|
||||||
|
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||||
|
audiences = []string{protoJWT.GetAudience()}
|
||||||
|
}
|
||||||
|
srv.SetJWTConfig(&vncserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audiences: audiences,
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
})
|
||||||
|
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
|
||||||
|
}
|
||||||
|
|
||||||
|
e.configureVNCRecording(srv, sshConf)
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
srv.SetNetstackNet(netstackNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||||
|
network := e.wgInterface.Address().Network
|
||||||
|
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||||
|
return fmt.Errorf("start VNC server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.vncSrv = srv
|
||||||
|
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||||
|
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.setupVNCPortRedirection(); err != nil {
|
||||||
|
log.Warnf("setup VNC port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("VNC server enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureVNCRecording enables session recording on the VNC server from the
|
||||||
|
// management-supplied settings. The env var NB_VNC_FORCE_RECORDING overrides
|
||||||
|
// the API for local development: when set, recording is always enabled and
|
||||||
|
// writes into that directory. Otherwise recordings go next to the state file
|
||||||
|
// under vnc-recordings/.
|
||||||
|
func (e *Engine) configureVNCRecording(srv *vncserver.Server, sshConf *mgmProto.SSHConfig) {
|
||||||
|
recDir := os.Getenv(envVNCForceRecording)
|
||||||
|
apiEnabled := sshConf.GetEnableRecording()
|
||||||
|
|
||||||
|
if recDir == "" && !apiEnabled {
|
||||||
|
log.Debugf("VNC recording disabled (env=%q, api=%v)", recDir, apiEnabled)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if recDir == "" {
|
||||||
|
base := e.defaultRecordingBase()
|
||||||
|
if base == "" {
|
||||||
|
log.Warn("VNC recording requested by management but no state directory is available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
recDir = filepath.Join(base, "vnc-recordings")
|
||||||
|
} else {
|
||||||
|
recDir = filepath.Join(recDir, "vnc")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.SetRecordingDir(recDir)
|
||||||
|
log.Infof("VNC recording enabled (dir=%s, source=%s)", recDir, recordingSource(apiEnabled))
|
||||||
|
|
||||||
|
encKey := string(sshConf.GetRecordingEncryptionKey())
|
||||||
|
if encKey == "" {
|
||||||
|
encKey = os.Getenv("NB_VNC_RECORDING_ENCRYPTION_KEY")
|
||||||
|
}
|
||||||
|
if encKey != "" {
|
||||||
|
srv.SetRecordingEncryptionKey(encKey)
|
||||||
|
log.Info("VNC recording encryption enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) defaultRecordingBase() string {
|
||||||
|
if e.stateManager == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
p := e.stateManager.FilePath()
|
||||||
|
if p == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return filepath.Dir(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordingSource(api bool) string {
|
||||||
|
if api {
|
||||||
|
return "management"
|
||||||
|
}
|
||||||
|
return "env"
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateVNCServerJWT configures the JWT validation for the VNC server using
|
||||||
|
// the same JWT config as SSH (same identity provider).
|
||||||
|
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
|
||||||
|
if e.vncSrv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||||
|
vncSrv.SetDisableAuth(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoJWT := sshConf.GetJwtConfig()
|
||||||
|
if protoJWT == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
audiences := protoJWT.GetAudiences()
|
||||||
|
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||||
|
audiences = []string{protoJWT.GetAudience()}
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audiences: audiences,
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||||
|
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||||
|
if vncAuth == nil || e.vncSrv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||||
|
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||||
|
for i, hash := range protoUsers {
|
||||||
|
if len(hash) != 16 {
|
||||||
|
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||||
|
machineUsers[osUser] = indexes.GetIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||||
|
UserIDClaim: vncAuth.GetUserIDClaim(),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
MachineUsers: machineUsers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVNCServerStatus returns whether the VNC server is running.
|
||||||
|
func (e *Engine) GetVNCServerStatus() bool {
|
||||||
|
return e.vncSrv != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopVNCServer() error {
|
||||||
|
if e.vncSrv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||||
|
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("stopping VNC server")
|
||||||
|
err := e.vncSrv.Stop()
|
||||||
|
e.vncSrv = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop VNC server: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
23
client/internal/engine_vnc_darwin.go
Normal file
23
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
capturer := vncserver.NewMacPoller()
|
||||||
|
injector, err := vncserver.NewMacInputInjector()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("VNC: macOS input injector: %v", err)
|
||||||
|
return capturer, &vncserver.StubInputInjector{}
|
||||||
|
}
|
||||||
|
return capturer, injector
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
13
client/internal/engine_vnc_stub.go
Normal file
13
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return vncserver.GetCurrentSessionID() == 0
|
||||||
|
}
|
||||||
23
client/internal/engine_vnc_x11.go
Normal file
23
client/internal/engine_vnc_x11.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
capturer := vncserver.NewX11Poller("")
|
||||||
|
injector, err := vncserver.NewX11InputInjector("")
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("VNC: X11 input injector: %v", err)
|
||||||
|
return capturer, &vncserver.StubInputInjector{}
|
||||||
|
}
|
||||||
|
return capturer, injector
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package activity
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,6 +18,10 @@ import (
|
|||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isBindListenerPlatform() bool {
|
||||||
|
return runtime.GOOS == "windows" || runtime.GOOS == "js"
|
||||||
|
}
|
||||||
|
|
||||||
// mockEndpointManager implements device.EndpointManager for testing
|
// mockEndpointManager implements device.EndpointManager for testing
|
||||||
type mockEndpointManager struct {
|
type mockEndpointManager struct {
|
||||||
endpoints map[netip.Addr]net.Conn
|
endpoints map[netip.Addr]net.Conn
|
||||||
@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_BindMode(t *testing.T) {
|
func TestManager_BindMode(t *testing.T) {
|
||||||
|
if !isBindListenerPlatform() {
|
||||||
|
t.Skip("BindListener only used on Windows/JS platforms")
|
||||||
|
}
|
||||||
|
|
||||||
mockEndpointMgr := newMockEndpointManager()
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||||
|
if !isBindListenerPlatform() {
|
||||||
|
t.Skip("BindListener only used on Windows/JS platforms")
|
||||||
|
}
|
||||||
|
|
||||||
mockEndpointMgr := newMockEndpointManager()
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
|
// - JS: Cannot listen to UDP sockets
|
||||||
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
provider, ok := m.wgIface.(bindProvider)
|
provider, ok := m.wgIface.(bindProvider)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
||||||
|
|||||||
@@ -185,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||||
|
|
||||||
forceRelay := IsForceRelayed()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
if !forceRelay {
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
if err != nil {
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
return err
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.workerICE = workerICE
|
|
||||||
}
|
}
|
||||||
|
conn.workerICE = workerICE
|
||||||
|
|
||||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||||
|
|
||||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||||
if !forceRelay {
|
if !isForceRelayed() {
|
||||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.wgWatcherCancel()
|
conn.wgWatcherCancel()
|
||||||
}
|
}
|
||||||
conn.workerRelay.CloseConn()
|
conn.workerRelay.CloseConn()
|
||||||
if conn.workerICE != nil {
|
conn.workerICE.Close()
|
||||||
conn.workerICE.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
err := conn.wgProxyRelay.CloseConn()
|
err := conn.wgProxyRelay.CloseConn()
|
||||||
@@ -299,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
|||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||||
conn.dumpState.RemoteCandidate()
|
conn.dumpState.RemoteCandidate()
|
||||||
if conn.workerICE != nil {
|
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||||
@@ -719,35 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
|
|||||||
return StatusConnecting
|
return StatusConnecting
|
||||||
}
|
}
|
||||||
|
|
||||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||||
//
|
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||||
// The result is a tri-state:
|
|
||||||
// - ConnStatusConnected: all available transports are up
|
|
||||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
|
||||||
// - ConnStatusDisconnected: no working transport
|
|
||||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if status == guard.ConnStatusDisconnected {
|
if !connected {
|
||||||
conn.logTraceConnState()
|
conn.logTraceConnState()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
iceWorkerCreated := conn.workerICE != nil
|
// For JS platform: only relay connection is supported
|
||||||
|
if runtime.GOOS == "js" {
|
||||||
var iceInProgress bool
|
return conn.statusRelay.Get() == worker.StatusConnected
|
||||||
if iceWorkerCreated {
|
|
||||||
iceInProgress = conn.workerICE.InProgress()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return evalConnStatus(connStatusInputs{
|
// For non-JS platforms: check ICE connection status
|
||||||
forceRelay: IsForceRelayed(),
|
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
return false
|
||||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
}
|
||||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
|
||||||
iceWorkerCreated: iceWorkerCreated,
|
// If relay is supported with peer, it must also be connected
|
||||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
iceInProgress: iceInProgress,
|
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||||
})
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||||
@@ -935,43 +926,3 @@ func isController(config ConnConfig) bool {
|
|||||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||||
return remoteRosenpassPubKey != nil
|
return remoteRosenpassPubKey != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
|
||||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
|
||||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
|
||||||
|
|
||||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
|
||||||
if in.forceRelay {
|
|
||||||
return boolToConnStatus(relayUsedAndUp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
|
||||||
// relay is the only possible transport.
|
|
||||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
|
||||||
return boolToConnStatus(relayUsedAndUp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
|
||||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
|
||||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
|
||||||
|
|
||||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
|
||||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case iceUp && relayOK:
|
|
||||||
return guard.ConnStatusConnected
|
|
||||||
case relayUsedAndUp:
|
|
||||||
// Relay is up but ICE is down — partially connected.
|
|
||||||
return guard.ConnStatusPartiallyConnected
|
|
||||||
default:
|
|
||||||
return guard.ConnStatusDisconnected
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
|
||||||
if connected {
|
|
||||||
return guard.ConnStatusConnected
|
|
||||||
}
|
|
||||||
return guard.ConnStatusDisconnected
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,20 +13,6 @@ const (
|
|||||||
StatusConnected
|
StatusConnected
|
||||||
)
|
)
|
||||||
|
|
||||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
|
||||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
|
||||||
// without constructing full Worker/Handshaker objects.
|
|
||||||
type connStatusInputs struct {
|
|
||||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
|
||||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
|
||||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
|
||||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
|
||||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
|
||||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
|
||||||
iceInProgress bool // a negotiation is currently in flight
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// ConnStatus describe the status of a peer's connection
|
// ConnStatus describe the status of a peer's connection
|
||||||
type ConnStatus int32
|
type ConnStatus int32
|
||||||
|
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
package peer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in connStatusInputs
|
|
||||||
want guard.ConnStatus
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "force relay, peer uses relay, relay up",
|
|
||||||
in: connStatusInputs{
|
|
||||||
forceRelay: true,
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "force relay, peer uses relay, relay down",
|
|
||||||
in: connStatusInputs{
|
|
||||||
forceRelay: true,
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: false,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
|
||||||
in: connStatusInputs{
|
|
||||||
forceRelay: true,
|
|
||||||
peerUsesRelay: false,
|
|
||||||
relayConnected: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
if got := evalConnStatus(tc.in); got != tc.want {
|
|
||||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in connStatusInputs
|
|
||||||
want guard.ConnStatus
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "remote does not support ICE, peer uses relay, relay up",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: true,
|
|
||||||
remoteSupportsICE: false,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote does not support ICE, peer uses relay, relay down",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: false,
|
|
||||||
remoteSupportsICE: false,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE worker not yet created, relay up",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: true,
|
|
||||||
relayConnected: true,
|
|
||||||
remoteSupportsICE: true,
|
|
||||||
iceWorkerCreated: false,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote does not support ICE, peer does not use relay",
|
|
||||||
in: connStatusInputs{
|
|
||||||
peerUsesRelay: false,
|
|
||||||
relayConnected: false,
|
|
||||||
remoteSupportsICE: false,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
if got := evalConnStatus(tc.in); got != tc.want {
|
|
||||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
|
||||||
base := connStatusInputs{
|
|
||||||
remoteSupportsICE: true,
|
|
||||||
iceWorkerCreated: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mutator func(*connStatusInputs)
|
|
||||||
want guard.ConnStatus
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ICE connected, relay connected, peer uses relay",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = true
|
|
||||||
in.relayConnected = true
|
|
||||||
in.iceStatusConnecting = true
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE connected, peer does NOT use relay",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.relayConnected = false
|
|
||||||
in.iceStatusConnecting = true
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE InProgress only, peer does NOT use relay",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = true
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE down, relay up, peer uses relay -> partial",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = true
|
|
||||||
in.relayConnected = true
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = false
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusPartiallyConnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.relayConnected = false
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = false
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = true
|
|
||||||
in.relayConnected = false
|
|
||||||
in.iceStatusConnecting = true
|
|
||||||
},
|
|
||||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
|
||||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
|
||||||
// falls into default: Disconnected.
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
|
||||||
mutator: func(in *connStatusInputs) {
|
|
||||||
in.peerUsesRelay = false
|
|
||||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
|
||||||
in.iceStatusConnecting = false
|
|
||||||
in.iceInProgress = false
|
|
||||||
},
|
|
||||||
want: guard.ConnStatusDisconnected,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
in := base
|
|
||||||
tc.mutator(&in)
|
|
||||||
if got := evalConnStatus(in); got != tc.want {
|
|
||||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,38 +7,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||||
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func IsForceRelayed() bool {
|
func isForceRelayed() bool {
|
||||||
if runtime.GOOS == "js" {
|
if runtime.GOOS == "js" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OverrideRelayURLs returns the relay server URL list set in
|
|
||||||
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
|
|
||||||
// the override is active. When the env var is unset, the boolean is false
|
|
||||||
// and the caller should keep the list received from the management server.
|
|
||||||
// Intended for lab/debug scenarios where a peer must pin to a specific home
|
|
||||||
// relay regardless of what management offers.
|
|
||||||
func OverrideRelayURLs() ([]string, bool) {
|
|
||||||
raw := os.Getenv(EnvKeyNBHomeRelayServers)
|
|
||||||
if raw == "" {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
parts := strings.Split(raw, ",")
|
|
||||||
urls := make([]string, 0, len(parts))
|
|
||||||
for _, p := range parts {
|
|
||||||
p = strings.TrimSpace(p)
|
|
||||||
if p != "" {
|
|
||||||
urls = append(urls, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(urls) == 0 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
return urls, true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,19 +8,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnStatus represents the connection state as seen by the guard.
|
type isConnectedFunc func() bool
|
||||||
type ConnStatus int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
|
||||||
ConnStatusDisconnected ConnStatus = iota
|
|
||||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
|
||||||
ConnStatusPartiallyConnected
|
|
||||||
// ConnStatusConnected means all required connections are established.
|
|
||||||
ConnStatusConnected
|
|
||||||
)
|
|
||||||
|
|
||||||
type connStatusFunc func() ConnStatus
|
|
||||||
|
|
||||||
// Guard is responsible for the reconnection logic.
|
// Guard is responsible for the reconnection logic.
|
||||||
// It will trigger to send an offer to the peer then has connection issues.
|
// It will trigger to send an offer to the peer then has connection issues.
|
||||||
@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
|
|||||||
// - ICE candidate changes
|
// - ICE candidate changes
|
||||||
type Guard struct {
|
type Guard struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isConnectedOnAllWay connStatusFunc
|
isConnectedOnAllWay isConnectedFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
srWatcher *SRWatcher
|
srWatcher *SRWatcher
|
||||||
relayedConnDisconnected chan struct{}
|
relayedConnDisconnected chan struct{}
|
||||||
iCEConnDisconnected chan struct{}
|
iCEConnDisconnected chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||||
return &Guard{
|
return &Guard{
|
||||||
log: log,
|
log: log,
|
||||||
isConnectedOnAllWay: isConnectedFn,
|
isConnectedOnAllWay: isConnectedFn,
|
||||||
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
// reconnectLoopWithRetry periodically check the connection status.
|
||||||
//
|
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
|
||||||
// - Connected: no action, the peer is fully reachable.
|
|
||||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
|
||||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
|
||||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
|
||||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
|
||||||
//
|
|
||||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
|
||||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
|
||||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||||
srReconnectedChan := g.srWatcher.NewListener()
|
srReconnectedChan := g.srWatcher.NewListener()
|
||||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||||
@@ -89,47 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
|||||||
|
|
||||||
tickerChannel := ticker.C
|
tickerChannel := ticker.C
|
||||||
|
|
||||||
iceState := &iceRetryState{log: g.log}
|
|
||||||
defer iceState.reset()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-tickerChannel:
|
case t := <-tickerChannel:
|
||||||
switch g.isConnectedOnAllWay() {
|
if t.IsZero() {
|
||||||
case ConnStatusConnected:
|
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||||
// all good, nothing to do
|
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||||
case ConnStatusDisconnected:
|
tickerChannel = make(<-chan time.Time)
|
||||||
callback()
|
continue
|
||||||
case ConnStatusPartiallyConnected:
|
|
||||||
if iceState.shouldRetry() {
|
|
||||||
callback()
|
|
||||||
} else {
|
|
||||||
iceState.enterHourlyMode()
|
|
||||||
ticker.Stop()
|
|
||||||
tickerChannel = iceState.hourlyC()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !g.isConnectedOnAllWay() {
|
||||||
|
callback()
|
||||||
|
}
|
||||||
case <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.newReconnectTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
iceState.reset()
|
|
||||||
|
|
||||||
case <-g.iCEConnDisconnected:
|
case <-g.iCEConnDisconnected:
|
||||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.newReconnectTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
iceState.reset()
|
|
||||||
|
|
||||||
case <-srReconnectedChan:
|
case <-srReconnectedChan:
|
||||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
ticker = g.newReconnectTicker(ctx)
|
ticker = g.prepareExponentTicker(ctx)
|
||||||
tickerChannel = ticker.C
|
tickerChannel = ticker.C
|
||||||
iceState.reset()
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
g.log.Debugf("context is done, stop reconnect loop")
|
g.log.Debugf("context is done, stop reconnect loop")
|
||||||
@@ -152,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
|||||||
return backoff.NewTicker(bo)
|
return backoff.NewTicker(bo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
InitialInterval: 800 * time.Millisecond,
|
InitialInterval: 800 * time.Millisecond,
|
||||||
RandomizationFactor: 0.1,
|
RandomizationFactor: 0.1,
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
package guard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
|
||||||
maxICERetries = 3
|
|
||||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
|
||||||
iceRetryInterval = 1 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
|
||||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
|
||||||
type iceRetryState struct {
|
|
||||||
log *log.Entry
|
|
||||||
retries int
|
|
||||||
hourly *time.Ticker
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *iceRetryState) reset() {
|
|
||||||
s.retries = 0
|
|
||||||
if s.hourly != nil {
|
|
||||||
s.hourly.Stop()
|
|
||||||
s.hourly = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
|
||||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
|
||||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
|
||||||
func (s *iceRetryState) shouldRetry() bool {
|
|
||||||
if s.hourly != nil {
|
|
||||||
s.log.Debugf("hourly ICE retry attempt")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
s.retries++
|
|
||||||
if s.retries <= maxICERetries {
|
|
||||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
|
||||||
func (s *iceRetryState) enterHourlyMode() {
|
|
||||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
|
||||||
s.hourly = time.NewTicker(iceRetryInterval)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
|
||||||
if s.hourly == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return s.hourly.C
|
|
||||||
}
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
package guard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestRetryState() *iceRetryState {
|
|
||||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
|
|
||||||
for i := 1; i <= maxICERetries; i++ {
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
|
|
||||||
for i := 0; i < maxICERetries; i++ {
|
|
||||||
_ = s.shouldRetry()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
|
|
||||||
if s.hourlyC() != nil {
|
|
||||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
for i := 0; i < maxICERetries+1; i++ {
|
|
||||||
_ = s.shouldRetry()
|
|
||||||
}
|
|
||||||
|
|
||||||
s.enterHourlyMode()
|
|
||||||
defer s.reset()
|
|
||||||
|
|
||||||
if s.hourlyC() == nil {
|
|
||||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
s.enterHourlyMode()
|
|
||||||
defer s.reset()
|
|
||||||
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
for i := 0; i < maxICERetries+1; i++ {
|
|
||||||
_ = s.shouldRetry()
|
|
||||||
}
|
|
||||||
s.enterHourlyMode()
|
|
||||||
|
|
||||||
s.reset()
|
|
||||||
|
|
||||||
if s.hourlyC() != nil {
|
|
||||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
|
||||||
}
|
|
||||||
if s.retries != 0 {
|
|
||||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 1; i <= maxICERetries; i++ {
|
|
||||||
if !s.shouldRetry() {
|
|
||||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
|
||||||
s := newTestRetryState()
|
|
||||||
s.reset()
|
|
||||||
s.reset() // second call must not panic or re-stop a nil ticker
|
|
||||||
|
|
||||||
if s.hourlyC() != nil {
|
|
||||||
t.Fatalf("hourlyC non-nil after double reset")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
|||||||
return srw
|
return srw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
func (w *SRWatcher) Start() {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
w.cancelIceMonitor = cancel
|
w.cancelIceMonitor = cancel
|
||||||
|
|
||||||
if !disableICEMonitor {
|
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
|
||||||
}
|
|
||||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -44,10 +43,6 @@ type OfferAnswer struct {
|
|||||||
SessionID *ICESessionID
|
SessionID *ICESessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OfferAnswer) hasICECredentials() bool {
|
|
||||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handshaker struct {
|
type Handshaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
@@ -64,10 +59,6 @@ type Handshaker struct {
|
|||||||
relayListener *AsyncOfferListener
|
relayListener *AsyncOfferListener
|
||||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||||
|
|
||||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
|
||||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
|
||||||
remoteICESupported atomic.Bool
|
|
||||||
|
|
||||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||||
remoteOffersCh chan OfferAnswer
|
remoteOffersCh chan OfferAnswer
|
||||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||||
@@ -75,7 +66,7 @@ type Handshaker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||||
h := &Handshaker{
|
return &Handshaker{
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
@@ -85,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
|||||||
remoteOffersCh: make(chan OfferAnswer),
|
remoteOffersCh: make(chan OfferAnswer),
|
||||||
remoteAnswerCh: make(chan OfferAnswer),
|
remoteAnswerCh: make(chan OfferAnswer),
|
||||||
}
|
}
|
||||||
// assume remote supports ICE until we learn otherwise from received offers
|
|
||||||
h.remoteICESupported.Store(ice != nil)
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handshaker) RemoteICESupported() bool {
|
|
||||||
return h.remoteICESupported.Load()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
@@ -106,20 +90,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
|
|
||||||
// Record signaling received for reconnection attempts
|
// Record signaling received for reconnection attempts
|
||||||
if h.metricsStages != nil {
|
if h.metricsStages != nil {
|
||||||
h.metricsStages.RecordSignalingReceived()
|
h.metricsStages.RecordSignalingReceived()
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
|
||||||
|
|
||||||
if h.relayListener != nil {
|
if h.relayListener != nil {
|
||||||
h.relayListener.Notify(&remoteOfferAnswer)
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.iceListener != nil && h.RemoteICESupported() {
|
if h.iceListener != nil {
|
||||||
h.iceListener(&remoteOfferAnswer)
|
h.iceListener(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,20 +110,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
|
|
||||||
// Record signaling received for reconnection attempts
|
// Record signaling received for reconnection attempts
|
||||||
if h.metricsStages != nil {
|
if h.metricsStages != nil {
|
||||||
h.metricsStages.RecordSignalingReceived()
|
h.metricsStages.RecordSignalingReceived()
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
|
||||||
|
|
||||||
if h.relayListener != nil {
|
if h.relayListener != nil {
|
||||||
h.relayListener.Notify(&remoteOfferAnswer)
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.iceListener != nil && h.RemoteICESupported() {
|
if h.iceListener != nil {
|
||||||
h.iceListener(&remoteOfferAnswer)
|
h.iceListener(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -203,18 +183,15 @@ func (h *Handshaker) sendAnswer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||||
|
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||||
|
sid := h.ice.SessionID()
|
||||||
answer := OfferAnswer{
|
answer := OfferAnswer{
|
||||||
|
IceCredentials: IceCredentials{uFrag, pwd},
|
||||||
WgListenPort: h.config.LocalWgPort,
|
WgListenPort: h.config.LocalWgPort,
|
||||||
Version: version.NetbirdVersion(),
|
Version: version.NetbirdVersion(),
|
||||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||||
}
|
SessionID: &sid,
|
||||||
|
|
||||||
if h.ice != nil && h.RemoteICESupported() {
|
|
||||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
|
||||||
sid := h.ice.SessionID()
|
|
||||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
|
||||||
answer.SessionID = &sid
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||||
@@ -223,18 +200,3 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
|||||||
|
|
||||||
return answer
|
return answer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
|
||||||
hasICE := offer.hasICECredentials()
|
|
||||||
prev := h.remoteICESupported.Swap(hasICE)
|
|
||||||
if prev != hasICE {
|
|
||||||
if hasICE {
|
|
||||||
h.log.Infof("remote peer started sending ICE credentials")
|
|
||||||
} else {
|
|
||||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
|
||||||
if h.ice != nil {
|
|
||||||
h.ice.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,13 +46,9 @@ func (s *Signaler) Ready() bool {
|
|||||||
|
|
||||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||||
var sessionIDBytes []byte
|
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||||
if offerAnswer.SessionID != nil {
|
if err != nil {
|
||||||
var err error
|
log.Warnf("failed to get session ID bytes: %v", err)
|
||||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get session ID bytes: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
msg, err := signal.MarshalCredential(
|
msg, err := signal.MarshalCredential(
|
||||||
s.wgPrivateKey,
|
s.wgPrivateKey,
|
||||||
|
|||||||
@@ -64,11 +64,13 @@ type ConfigInput struct {
|
|||||||
StateFilePath string
|
StateFilePath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
ServerVNCAllowed *bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
DisableVNCAuth *bool
|
||||||
SSHJWTCacheTTL *int
|
SSHJWTCacheTTL *int
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
@@ -114,11 +116,13 @@ type Config struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
ServerVNCAllowed *bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
DisableVNCAuth *bool
|
||||||
SSHJWTCacheTTL *int
|
SSHJWTCacheTTL *int
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
@@ -415,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.ServerVNCAllowed != nil {
|
||||||
|
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||||
|
if *input.ServerVNCAllowed {
|
||||||
|
log.Infof("enabling VNC server")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling VNC server")
|
||||||
|
}
|
||||||
|
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
} else if config.ServerVNCAllowed == nil {
|
||||||
|
config.ServerVNCAllowed = util.True()
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||||
if *input.EnableSSHRoot {
|
if *input.EnableSSHRoot {
|
||||||
log.Infof("enabling SSH root login")
|
log.Infof("enabling SSH root login")
|
||||||
@@ -465,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
|
||||||
|
if *input.DisableVNCAuth {
|
||||||
|
log.Infof("disabling VNC authentication")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling VNC authentication")
|
||||||
|
}
|
||||||
|
config.DisableVNCAuth = input.DisableVNCAuth
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||||
|
|||||||
@@ -89,16 +89,8 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
|||||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
||||||
}
|
}
|
||||||
|
|
||||||
reused := false
|
|
||||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||||
if !errors.Is(err, unix.EEXIST) {
|
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
|
||||||
}
|
|
||||||
// macOS installs its own RTF_IFSCOPE defaults for primary service
|
|
||||||
// selection on multi-NIC setups, so a route on this ifindex can
|
|
||||||
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
|
|
||||||
// still produces the scoped lookup we need.
|
|
||||||
reused = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
af := unix.AF_INET
|
af := unix.AF_INET
|
||||||
@@ -110,11 +102,7 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
|||||||
if nexthop.IP.IsValid() {
|
if nexthop.IP.IsValid() {
|
||||||
via = nexthop.IP.String()
|
via = nexthop.IP.String()
|
||||||
}
|
}
|
||||||
verb := "installed"
|
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
||||||
if reused {
|
|
||||||
verb = "reused existing"
|
|
||||||
}
|
|
||||||
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,358 +2,217 @@
|
|||||||
|
|
||||||
package sleep
|
package sleep
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||||
|
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||||
|
#include <IOKit/IOMessage.h>
|
||||||
|
#include <CoreFoundation/CoreFoundation.h>
|
||||||
|
|
||||||
|
extern void sleepCallbackBridge();
|
||||||
|
extern void poweredOnCallbackBridge();
|
||||||
|
extern void suspendedCallbackBridge();
|
||||||
|
extern void resumedCallbackBridge();
|
||||||
|
|
||||||
|
|
||||||
|
// C global variables for IOKit state
|
||||||
|
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||||
|
static io_object_t g_notifierObject = 0;
|
||||||
|
static io_object_t g_generalInterestNotifier = 0;
|
||||||
|
static io_connect_t g_rootPort = 0;
|
||||||
|
static CFRunLoopRef g_runLoop = NULL;
|
||||||
|
|
||||||
|
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||||
|
switch (messageType) {
|
||||||
|
case kIOMessageSystemWillSleep:
|
||||||
|
sleepCallbackBridge();
|
||||||
|
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||||
|
break;
|
||||||
|
case kIOMessageSystemHasPoweredOn:
|
||||||
|
poweredOnCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsSuspended:
|
||||||
|
suspendedCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsResumed:
|
||||||
|
resumedCallbackBridge();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNotifications() {
|
||||||
|
g_rootPort = IORegisterForSystemPower(
|
||||||
|
NULL,
|
||||||
|
&g_notifyPortRef,
|
||||||
|
(IOServiceInterestCallback)sleepCallback,
|
||||||
|
&g_notifierObject
|
||||||
|
);
|
||||||
|
|
||||||
|
if (g_rootPort == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
g_runLoop = CFRunLoopGetCurrent();
|
||||||
|
CFRunLoopRun();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unregisterNotifications() {
|
||||||
|
CFRunLoopRemoveSource(g_runLoop,
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
IODeregisterForSystemPower(&g_notifierObject);
|
||||||
|
IOServiceClose(g_rootPort);
|
||||||
|
IONotificationPortDestroy(g_notifyPortRef);
|
||||||
|
CFRunLoopStop(g_runLoop);
|
||||||
|
|
||||||
|
g_notifyPortRef = NULL;
|
||||||
|
g_notifierObject = 0;
|
||||||
|
g_rootPort = 0;
|
||||||
|
g_runLoop = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/ebitengine/purego"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IOKit message types from IOKit/IOMessage.h.
|
|
||||||
const (
|
|
||||||
kIOMessageCanSystemSleep uintptr = 0xe0000270
|
|
||||||
kIOMessageSystemWillSleep uintptr = 0xe0000280
|
|
||||||
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ioKit iokitFuncs
|
|
||||||
cf cfFuncs
|
|
||||||
cfCommonModes uintptr
|
|
||||||
|
|
||||||
libInitOnce sync.Once
|
|
||||||
libInitErr error
|
|
||||||
|
|
||||||
// callbackThunk is the single C-callable trampoline registered with IOKit.
|
|
||||||
callbackThunk uintptr
|
|
||||||
|
|
||||||
serviceRegistry = make(map[*Detector]struct{})
|
serviceRegistry = make(map[*Detector]struct{})
|
||||||
serviceRegistryMu sync.Mutex
|
serviceRegistryMu sync.Mutex
|
||||||
session *runLoopSession
|
|
||||||
|
|
||||||
// lifecycleMu serializes Register/Deregister so a new registration can't
|
|
||||||
// start a second runloop while a previous teardown is still pending.
|
|
||||||
lifecycleMu sync.Mutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// iokitFuncs holds IOKit symbols resolved once at init.
|
//export sleepCallbackBridge
|
||||||
type iokitFuncs struct {
|
func sleepCallbackBridge() {
|
||||||
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
|
log.Info("sleepCallbackBridge event triggered")
|
||||||
IODeregisterForSystemPower func(notifier *uintptr) int32
|
|
||||||
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
|
|
||||||
IOServiceClose func(connect uintptr) int32
|
|
||||||
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
|
|
||||||
IONotificationPortDestroy func(port uintptr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cfFuncs holds CoreFoundation symbols resolved once at init.
|
|
||||||
type cfFuncs struct {
|
|
||||||
CFRunLoopGetCurrent func() uintptr
|
|
||||||
CFRunLoopRun func()
|
|
||||||
CFRunLoopStop func(rl uintptr)
|
|
||||||
CFRunLoopAddSource func(rl, source, mode uintptr)
|
|
||||||
CFRunLoopRemoveSource func(rl, source, mode uintptr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
|
|
||||||
// session means no runloop is active and the next Register must start one.
|
|
||||||
type runLoopSession struct {
|
|
||||||
rl uintptr
|
|
||||||
port uintptr
|
|
||||||
notifier uintptr
|
|
||||||
rp uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
// detectorSnapshot pins a detector's callback and done channel so dispatch
|
|
||||||
// runs with values valid at snapshot time, even if a concurrent
|
|
||||||
// Deregister/Register rewrites the detector's fields.
|
|
||||||
type detectorSnapshot struct {
|
|
||||||
detector *Detector
|
|
||||||
callback func(event EventType)
|
|
||||||
done <-chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detector delivers sleep and wake events to a registered callback.
|
|
||||||
type Detector struct {
|
|
||||||
callback func(event EventType)
|
|
||||||
done chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register installs callback for power events. The first registration starts
|
|
||||||
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
|
|
||||||
// registration succeeds or fails; subsequent registrations just add to the
|
|
||||||
// dispatch set.
|
|
||||||
func (d *Detector) Register(callback func(event EventType)) error {
|
|
||||||
lifecycleMu.Lock()
|
|
||||||
defer lifecycleMu.Unlock()
|
|
||||||
|
|
||||||
serviceRegistryMu.Lock()
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeSleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export resumedCallbackBridge
|
||||||
|
func resumedCallbackBridge() {
|
||||||
|
log.Info("resumedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export suspendedCallbackBridge
|
||||||
|
func suspendedCallbackBridge() {
|
||||||
|
log.Info("suspendedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export poweredOnCallbackBridge
|
||||||
|
func poweredOnCallbackBridge() {
|
||||||
|
log.Info("poweredOnCallbackBridge event triggered")
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeWakeUp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
callback func(event EventType)
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector() (*Detector, error) {
|
||||||
|
return &Detector{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Register(callback func(event EventType)) error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
if _, exists := serviceRegistry[d]; exists {
|
if _, exists := serviceRegistry[d]; exists {
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
return fmt.Errorf("detector service already registered")
|
return fmt.Errorf("detector service already registered")
|
||||||
}
|
}
|
||||||
d.callback = callback
|
|
||||||
d.done = make(chan struct{})
|
|
||||||
serviceRegistry[d] = struct{}{}
|
|
||||||
needSetup := session == nil
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
if !needSetup {
|
d.callback = callback
|
||||||
|
|
||||||
|
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
serviceRegistry[d] = struct{}{}
|
||||||
go runRunLoop(errCh)
|
|
||||||
if err := <-errCh; err != nil {
|
// CFRunLoop must run on a single fixed OS thread
|
||||||
serviceRegistryMu.Lock()
|
go func() {
|
||||||
delete(serviceRegistry, d)
|
runtime.LockOSThread()
|
||||||
close(d.done)
|
defer runtime.UnlockOSThread()
|
||||||
d.done = nil
|
|
||||||
serviceRegistryMu.Unlock()
|
C.registerNotifications()
|
||||||
return err
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("sleep detection service started on macOS")
|
log.Info("sleep detection service started on macOS")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deregister removes the detector. When the last detector leaves, IOKit
|
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||||
// notifications are torn down and the runloop is stopped.
|
// and the runloop is stopped and cleaned up.
|
||||||
func (d *Detector) Deregister() error {
|
func (d *Detector) Deregister() error {
|
||||||
lifecycleMu.Lock()
|
|
||||||
defer lifecycleMu.Unlock()
|
|
||||||
|
|
||||||
serviceRegistryMu.Lock()
|
serviceRegistryMu.Lock()
|
||||||
if _, exists := serviceRegistry[d]; !exists {
|
defer serviceRegistryMu.Unlock()
|
||||||
serviceRegistryMu.Unlock()
|
_, exists := serviceRegistry[d]
|
||||||
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
close(d.done)
|
|
||||||
|
// cancel and remove this detector
|
||||||
|
d.cancel()
|
||||||
delete(serviceRegistry, d)
|
delete(serviceRegistry, d)
|
||||||
|
|
||||||
|
// If other Detectors still exist, leave IOKit running
|
||||||
if len(serviceRegistry) > 0 {
|
if len(serviceRegistry) > 0 {
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
sess := session
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
log.Info("sleep detection service stopping (deregister)")
|
log.Info("sleep detection service stopping (deregister)")
|
||||||
|
|
||||||
if sess == nil {
|
// Deregister IOKit notifications, stop runloop, and free resources
|
||||||
return nil
|
C.unregisterNotifications()
|
||||||
}
|
|
||||||
|
|
||||||
if sess.rl != 0 && sess.port != 0 {
|
|
||||||
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
|
|
||||||
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
|
|
||||||
}
|
|
||||||
if sess.notifier != 0 {
|
|
||||||
n := sess.notifier
|
|
||||||
ioKit.IODeregisterForSystemPower(&n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear session only after IODeregisterForSystemPower returns so any
|
|
||||||
// in-flight powerCallback can still look up session.rp to ack sleep.
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
session = nil
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
if sess.rp != 0 {
|
|
||||||
ioKit.IOServiceClose(sess.rp)
|
|
||||||
}
|
|
||||||
if sess.port != 0 {
|
|
||||||
ioKit.IONotificationPortDestroy(sess.port)
|
|
||||||
}
|
|
||||||
if sess.rl != 0 {
|
|
||||||
cf.CFRunLoopStop(sess.rl)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
|
func (d *Detector) triggerCallback(event EventType) {
|
||||||
if cb == nil || done == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
timeout := time.NewTimer(500 * time.Millisecond)
|
timeout := time.NewTimer(500 * time.Millisecond)
|
||||||
defer timeout.Stop()
|
defer timeout.Stop()
|
||||||
|
|
||||||
go func() {
|
cb := d.callback
|
||||||
defer close(doneChan)
|
go func(callback func(event EventType)) {
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Errorf("panic in sleep callback: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Info("sleep detection event fired")
|
log.Info("sleep detection event fired")
|
||||||
cb(event)
|
callback(event)
|
||||||
}()
|
close(doneChan)
|
||||||
|
}(cb)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-doneChan:
|
case <-doneChan:
|
||||||
case <-done:
|
case <-d.ctx.Done():
|
||||||
case <-timeout.C:
|
case <-timeout.C:
|
||||||
log.Warn("sleep callback timed out")
|
log.Warnf("sleep callback timed out")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
|
|
||||||
func NewDetector() (*Detector, error) {
|
|
||||||
if err := initLibs(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Detector{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func initLibs() error {
|
|
||||||
libInitOnce.Do(func() {
|
|
||||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
|
||||||
if err != nil {
|
|
||||||
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
|
||||||
if err != nil {
|
|
||||||
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
|
|
||||||
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
|
|
||||||
|
|
||||||
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
|
|
||||||
if err != nil {
|
|
||||||
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Launder the uintptr-to-pointer conversion through a Go variable so
|
|
||||||
// go vet's unsafeptr analyzer doesn't flag a system-library global.
|
|
||||||
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
|
|
||||||
|
|
||||||
// NewCallback slots are a finite, non-reclaimable resource, so register
|
|
||||||
// a single thunk that dispatches to the current Detector set.
|
|
||||||
callbackThunk = purego.NewCallback(powerCallback)
|
|
||||||
})
|
|
||||||
return libInitErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
|
|
||||||
// runloop thread. A Go panic crossing the purego boundary has undefined
|
|
||||||
// behavior, so contain it here.
|
|
||||||
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Errorf("panic in sleep powerCallback: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
switch messageType {
|
|
||||||
case kIOMessageCanSystemSleep:
|
|
||||||
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
|
|
||||||
allowPowerChange(messageArgument)
|
|
||||||
case kIOMessageSystemWillSleep:
|
|
||||||
dispatchEvent(EventTypeSleep)
|
|
||||||
allowPowerChange(messageArgument)
|
|
||||||
case kIOMessageSystemHasPoweredOn:
|
|
||||||
dispatchEvent(EventTypeWakeUp)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func allowPowerChange(messageArgument uintptr) {
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
var port uintptr
|
|
||||||
if session != nil {
|
|
||||||
port = session.rp
|
|
||||||
}
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
if port != 0 {
|
|
||||||
ioKit.IOAllowPowerChange(port, messageArgument)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dispatchEvent(event EventType) {
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
|
|
||||||
for d := range serviceRegistry {
|
|
||||||
snaps = append(snaps, detectorSnapshot{
|
|
||||||
detector: d,
|
|
||||||
callback: d.callback,
|
|
||||||
done: d.done,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
for _, s := range snaps {
|
|
||||||
s.detector.triggerCallback(event, s.callback, s.done)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
|
|
||||||
// result is reported on errCh so Register can surface failures synchronously.
|
|
||||||
func runRunLoop(errCh chan<- error) {
|
|
||||||
runtime.LockOSThread()
|
|
||||||
defer runtime.UnlockOSThread()
|
|
||||||
|
|
||||||
sess, err := setupSession()
|
|
||||||
if err == nil {
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
session = sess
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
}
|
|
||||||
errCh <- err
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Errorf("panic in sleep runloop: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
cf.CFRunLoopRun()
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupSession performs the IOKit registration on the current thread. Panics
|
|
||||||
// are converted to errors so runRunLoop never leaves errCh unsent.
|
|
||||||
func setupSession() (s *runLoopSession, err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("panic during runloop setup: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var portRef, notifier uintptr
|
|
||||||
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, ¬ifier)
|
|
||||||
if rp == 0 {
|
|
||||||
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
rl := cf.CFRunLoopGetCurrent()
|
|
||||||
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
|
|
||||||
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
|
|
||||||
|
|
||||||
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FilePath returns the path of the underlying state file.
|
||||||
|
func (m *Manager) FilePath() string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return m.filePath
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the state manager periodic save routine
|
// Start starts the state manager periodic save routine
|
||||||
func (m *Manager) Start() {
|
func (m *Manager) Start() {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
|
|||||||
@@ -13,25 +13,15 @@
|
|||||||
|
|
||||||
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
|
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
|
||||||
|
|
||||||
<!-- Autostart: enabled by default, disable with AUTOSTART=0 on the msiexec command line -->
|
|
||||||
<Property Id="AUTOSTART" Value="1" />
|
|
||||||
|
|
||||||
<StandardDirectory Id="ProgramFiles64Folder">
|
<StandardDirectory Id="ProgramFiles64Folder">
|
||||||
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
||||||
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
||||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
|
||||||
</Shortcut>
|
|
||||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
|
||||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
|
||||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
|
||||||
</Shortcut>
|
|
||||||
</File>
|
</File>
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
|
||||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||||
<?endif ?>
|
<?endif ?>
|
||||||
@@ -56,31 +46,8 @@
|
|||||||
</Directory>
|
</Directory>
|
||||||
</StandardDirectory>
|
</StandardDirectory>
|
||||||
|
|
||||||
<!-- Per-user component: HKCU keypath (auto GUID via "*"), separate from
|
|
||||||
the per-machine NetbirdFiles component to satisfy ICE57. -->
|
|
||||||
<StandardDirectory Id="ProgramMenuFolder">
|
|
||||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
|
||||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
|
||||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
|
||||||
</RegistryKey>
|
|
||||||
</Component>
|
|
||||||
</StandardDirectory>
|
|
||||||
|
|
||||||
<StandardDirectory Id="CommonAppDataFolder">
|
|
||||||
<Directory Id="NetbirdAutoStartDir" Name="Netbird">
|
|
||||||
<Component Id="NetbirdAutoStart" Guid="b199eaca-b0dd-4032-af19-679cfad48eb3" Bitness="always64">
|
|
||||||
<Condition>AUTOSTART = "1"</Condition>
|
|
||||||
<RegistryValue Root="HKLM" Key="Software\Microsoft\Windows\CurrentVersion\Run"
|
|
||||||
Name="Netbird" Value=""[NetbirdInstallDir]netbird-ui.exe""
|
|
||||||
Type="string" KeyPath="yes" />
|
|
||||||
</Component>
|
|
||||||
</Directory>
|
|
||||||
</StandardDirectory>
|
|
||||||
|
|
||||||
<ComponentGroup Id="NetbirdFilesComponent">
|
<ComponentGroup Id="NetbirdFilesComponent">
|
||||||
<ComponentRef Id="NetbirdFiles" />
|
<ComponentRef Id="NetbirdFiles" />
|
||||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
|
||||||
<ComponentRef Id="NetbirdAutoStart" />
|
|
||||||
</ComponentGroup>
|
</ComponentGroup>
|
||||||
|
|
||||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -104,6 +104,8 @@ service DaemonService {
|
|||||||
// StopCPUProfile stops CPU profiling in the daemon
|
// StopCPUProfile stops CPU profiling in the daemon
|
||||||
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
||||||
|
|
||||||
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
|
|
||||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||||
|
|
||||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||||
@@ -112,6 +114,20 @@ service DaemonService {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
message OSLifecycleRequest {
|
||||||
|
// avoid collision with loglevel enum
|
||||||
|
enum CycleType {
|
||||||
|
UNKNOWN = 0;
|
||||||
|
SLEEP = 1;
|
||||||
|
WAKEUP = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CycleType type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message OSLifecycleResponse {}
|
||||||
|
|
||||||
|
|
||||||
message LoginRequest {
|
message LoginRequest {
|
||||||
// setupKey netbird setup key.
|
// setupKey netbird setup key.
|
||||||
string setupKey = 1;
|
string setupKey = 1;
|
||||||
@@ -193,6 +209,9 @@ message LoginRequest {
|
|||||||
optional bool enableSSHRemotePortForwarding = 37;
|
optional bool enableSSHRemotePortForwarding = 37;
|
||||||
optional bool disableSSHAuth = 38;
|
optional bool disableSSHAuth = 38;
|
||||||
optional int32 sshJWTCacheTTL = 39;
|
optional int32 sshJWTCacheTTL = 39;
|
||||||
|
|
||||||
|
optional bool serverVNCAllowed = 41;
|
||||||
|
optional bool disableVNCAuth = 42;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@@ -300,6 +319,10 @@ message GetConfigResponse {
|
|||||||
bool disableSSHAuth = 25;
|
bool disableSSHAuth = 25;
|
||||||
|
|
||||||
int32 sshJWTCacheTTL = 26;
|
int32 sshJWTCacheTTL = 26;
|
||||||
|
|
||||||
|
bool serverVNCAllowed = 28;
|
||||||
|
|
||||||
|
bool disableVNCAuth = 29;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@@ -378,6 +401,11 @@ message SSHServerState {
|
|||||||
repeated SSHSessionInfo sessions = 2;
|
repeated SSHSessionInfo sessions = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VNCServerState contains the latest state of the VNC server
|
||||||
|
message VNCServerState {
|
||||||
|
bool enabled = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
message FullStatus {
|
message FullStatus {
|
||||||
ManagementState managementState = 1;
|
ManagementState managementState = 1;
|
||||||
@@ -392,6 +420,7 @@ message FullStatus {
|
|||||||
|
|
||||||
bool lazyConnectionEnabled = 9;
|
bool lazyConnectionEnabled = 9;
|
||||||
SSHServerState sshServerState = 10;
|
SSHServerState sshServerState = 10;
|
||||||
|
VNCServerState vncServerState = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Networks
|
// Networks
|
||||||
@@ -661,6 +690,9 @@ message SetConfigRequest {
|
|||||||
optional bool enableSSHRemotePortForwarding = 32;
|
optional bool enableSSHRemotePortForwarding = 32;
|
||||||
optional bool disableSSHAuth = 33;
|
optional bool disableSSHAuth = 33;
|
||||||
optional int32 sshJWTCacheTTL = 34;
|
optional int32 sshJWTCacheTTL = 34;
|
||||||
|
|
||||||
|
optional bool serverVNCAllowed = 36;
|
||||||
|
optional bool disableVNCAuth = 37;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SetConfigResponse{}
|
message SetConfigResponse{}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -120,7 +120,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
|||||||
}
|
}
|
||||||
agent := &serverAgent{s}
|
agent := &serverAgent{s}
|
||||||
s.sleepHandler = sleephandler.New(agent)
|
s.sleepHandler = sleephandler.New(agent)
|
||||||
s.startSleepDetector()
|
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@@ -370,6 +369,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||||
|
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||||
config.NetworkMonitor = msg.NetworkMonitor
|
config.NetworkMonitor = msg.NetworkMonitor
|
||||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||||
@@ -386,6 +386,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
if msg.DisableSSHAuth != nil {
|
if msg.DisableSSHAuth != nil {
|
||||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||||
}
|
}
|
||||||
|
if msg.DisableVNCAuth != nil {
|
||||||
|
config.DisableVNCAuth = msg.DisableVNCAuth
|
||||||
|
}
|
||||||
if msg.SshJWTCacheTTL != nil {
|
if msg.SshJWTCacheTTL != nil {
|
||||||
ttl := int(*msg.SshJWTCacheTTL)
|
ttl := int(*msg.SshJWTCacheTTL)
|
||||||
config.SSHJWTCacheTTL = &ttl
|
config.SSHJWTCacheTTL = &ttl
|
||||||
@@ -1124,6 +1127,7 @@ func (s *Server) Status(
|
|||||||
pbFullStatus := fullStatus.ToProto()
|
pbFullStatus := fullStatus.ToProto()
|
||||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||||
|
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||||
statusResponse.FullStatus = pbFullStatus
|
statusResponse.FullStatus = pbFullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1163,6 +1167,26 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
|||||||
return sshServerState
|
return sshServerState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getVNCServerState retrieves the current VNC server state.
|
||||||
|
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||||
|
s.mutex.Lock()
|
||||||
|
connectClient := s.connectClient
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
if connectClient == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.VNCServerState{
|
||||||
|
Enabled: engine.GetVNCServerStatus(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||||
func (s *Server) GetPeerSSHHostKey(
|
func (s *Server) GetPeerSSHHostKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -1504,6 +1528,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
disableSSHAuth = *cfg.DisableSSHAuth
|
disableSSHAuth = *cfg.DisableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
disableVNCAuth := false
|
||||||
|
if cfg.DisableVNCAuth != nil {
|
||||||
|
disableVNCAuth = *cfg.DisableVNCAuth
|
||||||
|
}
|
||||||
|
|
||||||
sshJWTCacheTTL := int32(0)
|
sshJWTCacheTTL := int32(0)
|
||||||
if cfg.SSHJWTCacheTTL != nil {
|
if cfg.SSHJWTCacheTTL != nil {
|
||||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||||
@@ -1518,6 +1547,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
Mtu: int64(cfg.MTU),
|
Mtu: int64(cfg.MTU),
|
||||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||||
|
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||||
@@ -1533,6 +1563,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||||
DisableSSHAuth: disableSSHAuth,
|
DisableSSHAuth: disableSSHAuth,
|
||||||
|
DisableVNCAuth: disableVNCAuth,
|
||||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
rosenpassEnabled := true
|
rosenpassEnabled := true
|
||||||
rosenpassPermissive := true
|
rosenpassPermissive := true
|
||||||
serverSSHAllowed := true
|
serverSSHAllowed := true
|
||||||
|
serverVNCAllowed := true
|
||||||
|
disableVNCAuth := true
|
||||||
interfaceName := "utun100"
|
interfaceName := "utun100"
|
||||||
wireguardPort := int64(51820)
|
wireguardPort := int64(51820)
|
||||||
preSharedKey := "test-psk"
|
preSharedKey := "test-psk"
|
||||||
@@ -82,6 +84,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
RosenpassEnabled: &rosenpassEnabled,
|
RosenpassEnabled: &rosenpassEnabled,
|
||||||
RosenpassPermissive: &rosenpassPermissive,
|
RosenpassPermissive: &rosenpassPermissive,
|
||||||
ServerSSHAllowed: &serverSSHAllowed,
|
ServerSSHAllowed: &serverSSHAllowed,
|
||||||
|
ServerVNCAllowed: &serverVNCAllowed,
|
||||||
|
DisableVNCAuth: &disableVNCAuth,
|
||||||
InterfaceName: &interfaceName,
|
InterfaceName: &interfaceName,
|
||||||
WireguardPort: &wireguardPort,
|
WireguardPort: &wireguardPort,
|
||||||
OptionalPreSharedKey: &preSharedKey,
|
OptionalPreSharedKey: &preSharedKey,
|
||||||
@@ -125,6 +129,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||||
|
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||||
|
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||||
|
require.NotNil(t, cfg.DisableVNCAuth)
|
||||||
|
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
|
||||||
require.Equal(t, interfaceName, cfg.WgIface)
|
require.Equal(t, interfaceName, cfg.WgIface)
|
||||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||||
@@ -176,6 +184,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
|||||||
"RosenpassEnabled": true,
|
"RosenpassEnabled": true,
|
||||||
"RosenpassPermissive": true,
|
"RosenpassPermissive": true,
|
||||||
"ServerSSHAllowed": true,
|
"ServerSSHAllowed": true,
|
||||||
|
"ServerVNCAllowed": true,
|
||||||
|
"DisableVNCAuth": true,
|
||||||
"InterfaceName": true,
|
"InterfaceName": true,
|
||||||
"WireguardPort": true,
|
"WireguardPort": true,
|
||||||
"OptionalPreSharedKey": true,
|
"OptionalPreSharedKey": true,
|
||||||
@@ -236,6 +246,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
|||||||
"enable-rosenpass": "RosenpassEnabled",
|
"enable-rosenpass": "RosenpassEnabled",
|
||||||
"rosenpass-permissive": "RosenpassPermissive",
|
"rosenpass-permissive": "RosenpassPermissive",
|
||||||
"allow-server-ssh": "ServerSSHAllowed",
|
"allow-server-ssh": "ServerSSHAllowed",
|
||||||
|
"allow-server-vnc": "ServerVNCAllowed",
|
||||||
|
"disable-vnc-auth": "DisableVNCAuth",
|
||||||
"interface-name": "InterfaceName",
|
"interface-name": "InterfaceName",
|
||||||
"wireguard-port": "WireguardPort",
|
"wireguard-port": "WireguardPort",
|
||||||
"preshared-key": "OptionalPreSharedKey",
|
"preshared-key": "OptionalPreSharedKey",
|
||||||
|
|||||||
@@ -2,18 +2,13 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR"
|
|
||||||
|
|
||||||
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||||
type serverAgent struct {
|
type serverAgent struct {
|
||||||
s *Server
|
s *Server
|
||||||
@@ -33,61 +28,19 @@ func (a *serverAgent) Status() (internal.StatusType, error) {
|
|||||||
return internal.CtxGetState(a.s.rootCtx).Status()
|
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||||
}
|
}
|
||||||
|
|
||||||
// startSleepDetector starts the OS sleep/wake detector and forwards events to
|
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||||
// the sleep handler. On platforms without a supported detector the attempt
|
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||||
// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips
|
switch req.GetType() {
|
||||||
// registration entirely.
|
case proto.OSLifecycleRequest_WAKEUP:
|
||||||
func (s *Server) startSleepDetector() {
|
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||||
if sleepDetectorDisabled() {
|
return &proto.OSLifecycleResponse{}, err
|
||||||
log.Info("sleep detection disabled via " + envDisableSleepDetector)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
svc, err := sleep.New()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to initialize sleep detection: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = svc.Register(func(event sleep.EventType) {
|
|
||||||
switch event {
|
|
||||||
case sleep.EventTypeSleep:
|
|
||||||
log.Info("handling sleep event")
|
|
||||||
if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil {
|
|
||||||
log.Errorf("failed to handle sleep event: %v", err)
|
|
||||||
}
|
|
||||||
case sleep.EventTypeWakeUp:
|
|
||||||
log.Info("handling wakeup event")
|
|
||||||
if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil {
|
|
||||||
log.Errorf("failed to handle wakeup event: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
case proto.OSLifecycleRequest_SLEEP:
|
||||||
if err != nil {
|
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||||
log.Errorf("failed to register sleep detector: %v", err)
|
return &proto.OSLifecycleResponse{}, err
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("sleep detection service initialized")
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-s.rootCtx.Done()
|
|
||||||
log.Info("stopping sleep event listener")
|
|
||||||
if err := svc.Deregister(); err != nil {
|
|
||||||
log.Errorf("failed to deregister sleep detector: %v", err)
|
|
||||||
}
|
}
|
||||||
}()
|
default:
|
||||||
}
|
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||||
|
|
||||||
func sleepDetectorDisabled() bool {
|
|
||||||
val := os.Getenv(envDisableSleepDetector)
|
|
||||||
if val == "" {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
disabled, err := strconv.ParseBool(val)
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return disabled
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
// generateS4UUserToken creates a Windows token using S4U authentication.
|
||||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
// This is the same approach OpenSSH for Windows uses for public key authentication.
|
||||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
userCpn := buildUserCpn(username, domain)
|
userCpn := buildUserCpn(username, domain)
|
||||||
|
|
||||||
|
|||||||
@@ -507,27 +507,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
|||||||
maxTokenAge = DefaultJWTMaxTokenAge
|
maxTokenAge = DefaultJWTMaxTokenAge
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
|
||||||
if !ok {
|
|
||||||
userID := extractUserID(token)
|
|
||||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
iat, ok := claims["iat"].(float64)
|
|
||||||
if !ok {
|
|
||||||
userID := extractUserID(token)
|
|
||||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
issuedAt := time.Unix(int64(iat), 0)
|
|
||||||
tokenAge := time.Since(issuedAt)
|
|
||||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
|
||||||
if tokenAge > maxAge {
|
|
||||||
userID := getUserIDFromClaims(claims)
|
|
||||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||||
@@ -558,27 +538,7 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func extractUserID(token *gojwt.Token) string {
|
func extractUserID(token *gojwt.Token) string {
|
||||||
if token == nil {
|
return jwt.UserIDFromToken(token)
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
return getUserIDFromClaims(claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
|
||||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
|
||||||
return sub
|
|
||||||
}
|
|
||||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
|
||||||
return userID
|
|
||||||
}
|
|
||||||
if email, ok := claims["email"].(string); ok && email != "" {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
return "unknown"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||||
|
|||||||
@@ -130,6 +130,10 @@ type SSHServerStateOutput struct {
|
|||||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type VNCServerStateOutput struct {
|
||||||
|
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
type OutputOverview struct {
|
type OutputOverview struct {
|
||||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||||
@@ -151,6 +155,7 @@ type OutputOverview struct {
|
|||||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||||
|
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||||
@@ -171,6 +176,9 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
|||||||
|
|
||||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||||
|
vncServerOverview := VNCServerStateOutput{
|
||||||
|
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
|
||||||
|
}
|
||||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||||
|
|
||||||
overview := OutputOverview{
|
overview := OutputOverview{
|
||||||
@@ -194,6 +202,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
|||||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||||
ProfileName: opts.ProfileName,
|
ProfileName: opts.ProfileName,
|
||||||
SSHServerState: sshServerOverview,
|
SSHServerState: sshServerOverview,
|
||||||
|
VNCServerState: vncServerOverview,
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Anonymize {
|
if opts.Anonymize {
|
||||||
@@ -524,6 +533,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vncServerStatus := "Disabled"
|
||||||
|
if o.VNCServerState.Enabled {
|
||||||
|
vncServerStatus = "Enabled"
|
||||||
|
}
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||||
|
|
||||||
var forwardingRulesString string
|
var forwardingRulesString string
|
||||||
@@ -553,6 +567,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
"Quantum resistance: %s\n"+
|
"Quantum resistance: %s\n"+
|
||||||
"Lazy connection: %s\n"+
|
"Lazy connection: %s\n"+
|
||||||
"SSH Server: %s\n"+
|
"SSH Server: %s\n"+
|
||||||
|
"VNC Server: %s\n"+
|
||||||
"Networks: %s\n"+
|
"Networks: %s\n"+
|
||||||
"%s"+
|
"%s"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
@@ -570,6 +585,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
lazyConnectionEnabledStatus,
|
lazyConnectionEnabledStatus,
|
||||||
sshServerStatus,
|
sshServerStatus,
|
||||||
|
vncServerStatus,
|
||||||
networks,
|
networks,
|
||||||
forwardingRulesString,
|
forwardingRulesString,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
|
|||||||
@@ -398,6 +398,9 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"sshServer":{
|
"sshServer":{
|
||||||
"enabled":false,
|
"enabled":false,
|
||||||
"sessions":[]
|
"sessions":[]
|
||||||
|
},
|
||||||
|
"vncServer":{
|
||||||
|
"enabled":false
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
@@ -505,6 +508,8 @@ profileName: ""
|
|||||||
sshServer:
|
sshServer:
|
||||||
enabled: false
|
enabled: false
|
||||||
sessions: []
|
sessions: []
|
||||||
|
vncServer:
|
||||||
|
enabled: false
|
||||||
`
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
@@ -572,6 +577,7 @@ Interface type: Kernel
|
|||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Lazy connection: false
|
Lazy connection: false
|
||||||
SSH Server: Disabled
|
SSH Server: Disabled
|
||||||
|
VNC Server: Disabled
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||||
@@ -596,6 +602,7 @@ Interface type: Kernel
|
|||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Lazy connection: false
|
Lazy connection: false
|
||||||
SSH Server: Disabled
|
SSH Server: Disabled
|
||||||
|
VNC Server: Disabled
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
`
|
`
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ type Info struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
ServerVNCAllowed bool
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -77,21 +78,27 @@ type Info struct {
|
|||||||
EnableSSHLocalPortForwarding bool
|
EnableSSHLocalPortForwarding bool
|
||||||
EnableSSHRemotePortForwarding bool
|
EnableSSHRemotePortForwarding bool
|
||||||
DisableSSHAuth bool
|
DisableSSHAuth bool
|
||||||
|
DisableVNCAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Info) SetFlags(
|
func (i *Info) SetFlags(
|
||||||
rosenpassEnabled, rosenpassPermissive bool,
|
rosenpassEnabled, rosenpassPermissive bool,
|
||||||
serverSSHAllowed *bool,
|
serverSSHAllowed *bool,
|
||||||
|
serverVNCAllowed *bool,
|
||||||
disableClientRoutes, disableServerRoutes,
|
disableClientRoutes, disableServerRoutes,
|
||||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
||||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||||
disableSSHAuth *bool,
|
disableSSHAuth *bool,
|
||||||
|
disableVNCAuth *bool,
|
||||||
) {
|
) {
|
||||||
i.RosenpassEnabled = rosenpassEnabled
|
i.RosenpassEnabled = rosenpassEnabled
|
||||||
i.RosenpassPermissive = rosenpassPermissive
|
i.RosenpassPermissive = rosenpassPermissive
|
||||||
if serverSSHAllowed != nil {
|
if serverSSHAllowed != nil {
|
||||||
i.ServerSSHAllowed = *serverSSHAllowed
|
i.ServerSSHAllowed = *serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if serverVNCAllowed != nil {
|
||||||
|
i.ServerVNCAllowed = *serverVNCAllowed
|
||||||
|
}
|
||||||
|
|
||||||
i.DisableClientRoutes = disableClientRoutes
|
i.DisableClientRoutes = disableClientRoutes
|
||||||
i.DisableServerRoutes = disableServerRoutes
|
i.DisableServerRoutes = disableServerRoutes
|
||||||
@@ -117,6 +124,9 @@ func (i *Info) SetFlags(
|
|||||||
if disableSSHAuth != nil {
|
if disableSSHAuth != nil {
|
||||||
i.DisableSSHAuth = *disableSSHAuth
|
i.DisableSSHAuth = *disableSSHAuth
|
||||||
}
|
}
|
||||||
|
if disableVNCAuth != nil {
|
||||||
|
i.DisableVNCAuth = *disableVNCAuth
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
|
|||||||
@@ -38,10 +38,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||||
"github.com/netbirdio/netbird/client/ui/event"
|
"github.com/netbirdio/netbird/client/ui/event"
|
||||||
"github.com/netbirdio/netbird/client/ui/notifier"
|
|
||||||
"github.com/netbirdio/netbird/client/ui/process"
|
"github.com/netbirdio/netbird/client/ui/process"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@@ -260,7 +260,6 @@ type serviceClient struct {
|
|||||||
|
|
||||||
// application with main windows.
|
// application with main windows.
|
||||||
app fyne.App
|
app fyne.App
|
||||||
notifier notifier.Notifier
|
|
||||||
wSettings fyne.Window
|
wSettings fyne.Window
|
||||||
showAdvancedSettings bool
|
showAdvancedSettings bool
|
||||||
sendNotification bool
|
sendNotification bool
|
||||||
@@ -365,7 +364,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
addr: args.addr,
|
addr: args.addr,
|
||||||
app: args.app,
|
app: args.app,
|
||||||
notifier: notifier.New(args.app),
|
|
||||||
logFile: args.logFile,
|
logFile: args.logFile,
|
||||||
sendNotification: false,
|
sendNotification: false,
|
||||||
|
|
||||||
@@ -894,7 +892,7 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get service status: %v", err)
|
log.Errorf("get service status: %v", err)
|
||||||
if s.connected {
|
if s.connected {
|
||||||
s.notifier.Send("Error", "Connection to service lost")
|
s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
|
||||||
}
|
}
|
||||||
s.setDisconnectedStatus()
|
s.setDisconnectedStatus()
|
||||||
return err
|
return err
|
||||||
@@ -1111,7 +1109,7 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s.eventManager = event.NewManager(s.notifier, s.addr)
|
s.eventManager = event.NewManager(s.app, s.addr)
|
||||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||||
if event.Category == proto.SystemEvent_SYSTEM {
|
if event.Category == proto.SystemEvent_SYSTEM {
|
||||||
@@ -1148,6 +1146,9 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
|
|
||||||
go s.eventManager.Start(s.ctx)
|
go s.eventManager.Start(s.ctx)
|
||||||
go s.eventHandler.listen(s.ctx)
|
go s.eventHandler.listen(s.ctx)
|
||||||
|
|
||||||
|
// Start sleep detection listener
|
||||||
|
go s.startSleepListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||||
@@ -1208,6 +1209,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
|||||||
return s.conn, nil
|
return s.conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startSleepListener initializes the sleep detection service and listens for sleep events
|
||||||
|
func (s *serviceClient) startSleepListener() {
|
||||||
|
sleepService, err := sleep.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("%v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sleepService.Register(s.handleSleepEvents); err != nil {
|
||||||
|
log.Errorf("failed to start sleep detection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service initialized")
|
||||||
|
|
||||||
|
// Cleanup on context cancellation
|
||||||
|
go func() {
|
||||||
|
<-s.ctx.Done()
|
||||||
|
log.Info("stopping sleep event listener")
|
||||||
|
if err := sleepService.Deregister(); err != nil {
|
||||||
|
log.Errorf("failed to deregister sleep detection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSleepEvents sends a sleep notification to the daemon via gRPC
|
||||||
|
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
|
||||||
|
conn, err := s.getSrvClient(0)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get daemon client for sleep notification: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &proto.OSLifecycleRequest{}
|
||||||
|
|
||||||
|
switch event {
|
||||||
|
case sleep.EventTypeWakeUp:
|
||||||
|
log.Infof("handle wakeup event: %v", event)
|
||||||
|
req.Type = proto.OSLifecycleRequest_WAKEUP
|
||||||
|
case sleep.EventTypeSleep:
|
||||||
|
log.Infof("handle sleep event: %v", event)
|
||||||
|
req.Type = proto.OSLifecycleRequest_SLEEP
|
||||||
|
default:
|
||||||
|
log.Infof("unknown event: %v", event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.NotifyOSLifecycle(s.ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("successfully notified daemon about os lifecycle")
|
||||||
|
}
|
||||||
|
|
||||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||||
if s.mSettings != nil {
|
if s.mSettings != nil {
|
||||||
@@ -1491,7 +1548,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) {
|
|||||||
|
|
||||||
if enforced && s.lastNotifiedVersion != newVersion {
|
if enforced && s.lastNotifiedVersion != newVersion {
|
||||||
s.lastNotifiedVersion = newVersion
|
s.lastNotifiedVersion = newVersion
|
||||||
s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install")
|
s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"fyne.io/fyne/v2"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -17,17 +18,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Notifier sends desktop notifications. Defined here so the event package
|
|
||||||
// does not depend on fyne or the platform-specific notifier implementation.
|
|
||||||
type Notifier interface {
|
|
||||||
Send(title, body string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handler func(*proto.SystemEvent)
|
type Handler func(*proto.SystemEvent)
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
notifier Notifier
|
app fyne.App
|
||||||
addr string
|
addr string
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -36,10 +31,10 @@ type Manager struct {
|
|||||||
handlers []Handler
|
handlers []Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(notifier Notifier, addr string) *Manager {
|
func NewManager(app fyne.App, addr string) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
notifier: notifier,
|
app: app,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +114,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
|
|||||||
if id != "" {
|
if id != "" {
|
||||||
body += fmt.Sprintf(" ID: %s", id)
|
body += fmt.Sprintf(" ID: %s", id)
|
||||||
}
|
}
|
||||||
e.notifier.Send(title, body)
|
e.app.SendNotification(fyne.NewNotification(title, body))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
|
"fyne.io/fyne/v2"
|
||||||
"fyne.io/systray"
|
"fyne.io/systray"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@@ -86,7 +87,7 @@ func (h *eventHandler) handleConnectClick() {
|
|||||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||||
log.Debugf("connect operation cancelled by user")
|
log.Debugf("connect operation cancelled by user")
|
||||||
} else {
|
} else {
|
||||||
h.client.notifier.Send("Error", "Failed to connect")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||||
log.Errorf("connect failed: %v", err)
|
log.Errorf("connect failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -111,7 +112,7 @@ func (h *eventHandler) handleDisconnectClick() {
|
|||||||
if err := h.client.menuDownClick(); err != nil {
|
if err := h.client.menuDownClick(); err != nil {
|
||||||
st, ok := status.FromError(err)
|
st, ok := status.FromError(err)
|
||||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||||
h.client.notifier.Send("Error", "Failed to disconnect")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||||
log.Errorf("disconnect failed: %v", err)
|
log.Errorf("disconnect failed: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("disconnect cancelled or already disconnecting")
|
log.Debugf("disconnect cancelled or already disconnecting")
|
||||||
@@ -129,7 +130,7 @@ func (h *eventHandler) handleAllowSSHClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update SSH settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings"))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -139,7 +140,7 @@ func (h *eventHandler) handleAutoConnectClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update auto-connect settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ func (h *eventHandler) handleRosenpassClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update Rosenpass settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +158,7 @@ func (h *eventHandler) handleLazyConnectionClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +167,7 @@ func (h *eventHandler) handleBlockInboundClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update block inbound settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ func (h *eventHandler) handleNotificationsClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update notifications settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings"))
|
||||||
} else if h.client.eventManager != nil {
|
} else if h.client.eventManager != nil {
|
||||||
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
|
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
// Package notifier sends desktop notifications. On Windows it uses the WinRT
|
|
||||||
// COM API directly via go-toast/v2 to avoid the PowerShell window flash that
|
|
||||||
// fyne's default implementation produces. On other platforms it delegates to
|
|
||||||
// fyne.
|
|
||||||
package notifier
|
|
||||||
|
|
||||||
import "fyne.io/fyne/v2"
|
|
||||||
|
|
||||||
// Notifier sends desktop notifications.
|
|
||||||
type Notifier interface {
|
|
||||||
Send(title, body string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a platform-specific Notifier. The fyne app is used as the
|
|
||||||
// fallback notifier on platforms where no native implementation is wired up,
|
|
||||||
// and on Windows when the COM path fails to initialize.
|
|
||||||
func New(app fyne.App) Notifier {
|
|
||||||
return newNotifier(app)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fyneNotifier struct {
|
|
||||||
app fyne.App
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fyneNotifier) Send(title, body string) {
|
|
||||||
f.app.SendNotification(fyne.NewNotification(title, body))
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package notifier
|
|
||||||
|
|
||||||
import "fyne.io/fyne/v2"
|
|
||||||
|
|
||||||
func newNotifier(app fyne.App) Notifier {
|
|
||||||
return &fyneNotifier{app: app}
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
package notifier
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"fyne.io/fyne/v2"
|
|
||||||
toast "git.sr.ht/~jackmordaunt/go-toast/v2"
|
|
||||||
"git.sr.ht/~jackmordaunt/go-toast/v2/wintoast"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// appID is the AppUserModelID shown in the Windows Action Center. It
|
|
||||||
// must match the System.AppUserModel.ID property set on the Start Menu
|
|
||||||
// shortcut by the MSI (see client/netbird.wxs); otherwise Windows
|
|
||||||
// groups toasts under a separate, unbranded entry.
|
|
||||||
appID = "NetBird"
|
|
||||||
|
|
||||||
// appGUID identifies the COM activation callback class. Generated once
|
|
||||||
// for NetBird; do not change without coordinating an installer bump,
|
|
||||||
// since old registry entries pointing at the previous GUID would orphan.
|
|
||||||
appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}"
|
|
||||||
)
|
|
||||||
|
|
||||||
type comNotifier struct {
|
|
||||||
fallback *fyneNotifier
|
|
||||||
ready bool
|
|
||||||
iconPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
initOnce sync.Once
|
|
||||||
initErr error
|
|
||||||
)
|
|
||||||
|
|
||||||
func newNotifier(app fyne.App) Notifier {
|
|
||||||
n := &comNotifier{
|
|
||||||
fallback: &fyneNotifier{app: app},
|
|
||||||
iconPath: resolveIcon(),
|
|
||||||
}
|
|
||||||
initOnce.Do(func() {
|
|
||||||
initErr = wintoast.SetAppData(wintoast.AppData{
|
|
||||||
AppID: appID,
|
|
||||||
GUID: appGUID,
|
|
||||||
IconPath: n.iconPath,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if initErr != nil {
|
|
||||||
log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr)
|
|
||||||
return n.fallback
|
|
||||||
}
|
|
||||||
n.ready = true
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *comNotifier) Send(title, body string) {
|
|
||||||
if !n.ready {
|
|
||||||
n.fallback.Send(title, body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
notification := toast.Notification{
|
|
||||||
AppID: appID,
|
|
||||||
Title: title,
|
|
||||||
Body: body,
|
|
||||||
Icon: n.iconPath,
|
|
||||||
}
|
|
||||||
if err := notification.Push(); err != nil {
|
|
||||||
log.Warnf("toast: push failed, using fyne fallback: %v", err)
|
|
||||||
n.fallback.Send(title, body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveIcon returns an absolute path to the toast icon, or an empty string
|
|
||||||
// when no icon can be located. Windows requires a PNG/JPG for the
|
|
||||||
// AppUserModelId IconUri registry value; .ico is silently ignored.
|
|
||||||
func resolveIcon() string {
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
candidate := filepath.Join(filepath.Dir(exe), "netbird.png")
|
|
||||||
if _, err := os.Stat(candidate); err == nil {
|
|
||||||
return candidate
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -548,7 +548,7 @@ func (p *profileMenu) refresh() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to switch profile: %v", err)
|
log.Errorf("failed to switch profile: %v", err)
|
||||||
// show notification dialog
|
// show notification dialog
|
||||||
p.serviceClient.notifier.Send("Error", "Failed to switch profile")
|
p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -628,9 +628,9 @@ func (p *profileMenu) refresh() {
|
|||||||
}
|
}
|
||||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||||
log.Errorf("logout failed: %v", err)
|
log.Errorf("logout failed: %v", err)
|
||||||
p.serviceClient.notifier.Send("Error", "Failed to deregister")
|
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||||
} else {
|
} else {
|
||||||
p.serviceClient.notifier.Send("Success", "Deregistered successfully")
|
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
474
client/vnc/server/agent_windows.go
Normal file
474
client/vnc/server/agent_windows.go
Normal file
@@ -0,0 +1,474 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
agentPort = "15900"
|
||||||
|
|
||||||
|
// agentTokenLen is the length of the random authentication token
|
||||||
|
// used to verify that connections to the agent come from the service.
|
||||||
|
agentTokenLen = 32
|
||||||
|
|
||||||
|
stillActive = 259
|
||||||
|
|
||||||
|
tokenPrimary = 1
|
||||||
|
securityImpersonation = 2
|
||||||
|
tokenSessionID = 12
|
||||||
|
|
||||||
|
createUnicodeEnvironment = 0x00000400
|
||||||
|
createNoWindow = 0x08000000
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||||
|
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||||
|
|
||||||
|
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||||
|
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||||
|
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||||
|
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||||
|
|
||||||
|
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||||
|
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||||
|
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetCurrentSessionID returns the session ID of the current process.
|
||||||
|
func GetCurrentSessionID() uint32 {
|
||||||
|
var token windows.Token
|
||||||
|
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||||
|
windows.TOKEN_QUERY, &token); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
defer token.Close()
|
||||||
|
var id uint32
|
||||||
|
var ret uint32
|
||||||
|
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||||
|
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConsoleSessionID() uint32 {
|
||||||
|
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||||
|
return uint32(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
wtsActive = 0
|
||||||
|
wtsConnected = 1
|
||||||
|
wtsDisconnected = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
type wtsSessionInfo struct {
|
||||||
|
SessionID uint32
|
||||||
|
WinStationName [66]byte // actually *uint16, but we just need the struct size
|
||||||
|
State uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||||
|
// Prefers an active (logged-in, interactive) session over the console session.
|
||||||
|
// This avoids kicking out an RDP user when the console is at the login screen.
|
||||||
|
func getActiveSessionID() uint32 {
|
||||||
|
var sessionInfo uintptr
|
||||||
|
var count uint32
|
||||||
|
|
||||||
|
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||||
|
0, // WTS_CURRENT_SERVER_HANDLE
|
||||||
|
0, // reserved
|
||||||
|
1, // version
|
||||||
|
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||||
|
uintptr(unsafe.Pointer(&count)),
|
||||||
|
)
|
||||||
|
if r == 0 || count == 0 {
|
||||||
|
return getConsoleSessionID()
|
||||||
|
}
|
||||||
|
defer procWTSFreeMemory.Call(sessionInfo)
|
||||||
|
|
||||||
|
type wtsSession struct {
|
||||||
|
SessionID uint32
|
||||||
|
Station *uint16
|
||||||
|
State uint32
|
||||||
|
}
|
||||||
|
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||||
|
|
||||||
|
// Find the first active session (not session 0, which is the services session).
|
||||||
|
var bestID uint32
|
||||||
|
found := false
|
||||||
|
for _, s := range sessions {
|
||||||
|
if s.SessionID == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.State == wtsActive {
|
||||||
|
bestID = s.SessionID
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return getConsoleSessionID()
|
||||||
|
}
|
||||||
|
return bestID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||||
|
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||||
|
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||||
|
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||||
|
var cur windows.Token
|
||||||
|
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||||
|
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||||
|
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||||
|
}
|
||||||
|
defer cur.Close()
|
||||||
|
|
||||||
|
var dup windows.Token
|
||||||
|
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||||
|
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||||
|
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sid := sessionID
|
||||||
|
r, _, err := procSetTokenInformation.Call(
|
||||||
|
uintptr(dup),
|
||||||
|
uintptr(tokenSessionID),
|
||||||
|
uintptr(unsafe.Pointer(&sid)),
|
||||||
|
unsafe.Sizeof(sid),
|
||||||
|
)
|
||||||
|
if r == 0 {
|
||||||
|
dup.Close()
|
||||||
|
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||||
|
}
|
||||||
|
return dup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
|
||||||
|
|
||||||
|
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||||
|
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||||
|
// an extra null. Returns a new block pointer with the entry added.
|
||||||
|
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
|
||||||
|
entry := key + "=" + value
|
||||||
|
|
||||||
|
// Walk the existing block to find its total length.
|
||||||
|
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||||
|
var totalChars int
|
||||||
|
for {
|
||||||
|
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||||
|
if ch == 0 {
|
||||||
|
// Check for double-null terminator.
|
||||||
|
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||||
|
totalChars++
|
||||||
|
if next == 0 {
|
||||||
|
// End of block (don't count the final null yet, we'll rebuild).
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
totalChars++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||||
|
// New block: existing entries + new entry (null-terminated) + final null.
|
||||||
|
newLen := totalChars + len(entryUTF16) + 1
|
||||||
|
newBlock := make([]uint16, newLen)
|
||||||
|
// Copy existing entries (up to but not including the final null).
|
||||||
|
for i := range totalChars {
|
||||||
|
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||||
|
}
|
||||||
|
copy(newBlock[totalChars:], entryUTF16)
|
||||||
|
newBlock[newLen-1] = 0 // final null terminator
|
||||||
|
|
||||||
|
return uintptr(unsafe.Pointer(&newBlock[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
|
||||||
|
token, err := getSystemTokenForSession(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||||
|
}
|
||||||
|
defer token.Close()
|
||||||
|
|
||||||
|
var envBlock uintptr
|
||||||
|
r, _, _ := procCreateEnvironmentBlock.Call(
|
||||||
|
uintptr(unsafe.Pointer(&envBlock)),
|
||||||
|
uintptr(token),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if r != 0 {
|
||||||
|
defer procDestroyEnvironmentBlock.Call(envBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject the auth token into the environment block so it doesn't appear
|
||||||
|
// in the process command line (visible via tasklist/wmic).
|
||||||
|
if r != 0 {
|
||||||
|
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get executable path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
|
||||||
|
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||||
|
// its output in the service process.
|
||||||
|
var sa windows.SecurityAttributes
|
||||||
|
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||||
|
sa.InheritHandle = 1
|
||||||
|
|
||||||
|
var stderrRead, stderrWrite windows.Handle
|
||||||
|
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||||
|
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||||
|
}
|
||||||
|
// The read end must NOT be inherited by the child.
|
||||||
|
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||||
|
|
||||||
|
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||||
|
si := windows.StartupInfo{
|
||||||
|
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||||
|
Desktop: desktop,
|
||||||
|
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||||
|
ShowWindow: 0,
|
||||||
|
StdErr: stderrWrite,
|
||||||
|
StdOutput: stderrWrite,
|
||||||
|
}
|
||||||
|
var pi windows.ProcessInformation
|
||||||
|
|
||||||
|
var envPtr *uint16
|
||||||
|
if envBlock != 0 {
|
||||||
|
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = windows.CreateProcessAsUser(
|
||||||
|
token, nil, cmdLineW,
|
||||||
|
nil, nil, true, // inheritHandles=true for the pipe
|
||||||
|
createUnicodeEnvironment|createNoWindow,
|
||||||
|
envPtr, nil, &si, &pi,
|
||||||
|
)
|
||||||
|
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||||
|
windows.CloseHandle(stderrWrite)
|
||||||
|
if err != nil {
|
||||||
|
windows.CloseHandle(stderrRead)
|
||||||
|
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||||
|
}
|
||||||
|
windows.CloseHandle(pi.Thread)
|
||||||
|
|
||||||
|
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||||
|
go relogAgentOutput(stderrRead)
|
||||||
|
|
||||||
|
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
|
||||||
|
return pi.Process, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionManager monitors the active console session and ensures a VNC agent
|
||||||
|
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||||
|
// connect/disconnect), it kills the old agent and spawns a new one.
|
||||||
|
type sessionManager struct {
|
||||||
|
port string
|
||||||
|
mu sync.Mutex
|
||||||
|
agentProc windows.Handle
|
||||||
|
sessionID uint32
|
||||||
|
authToken string
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSessionManager(port string) *sessionManager {
|
||||||
|
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateAuthToken creates a new random hex token for agent authentication.
|
||||||
|
func generateAuthToken() string {
|
||||||
|
b := make([]byte, agentTokenLen)
|
||||||
|
if _, err := crand.Read(b); err != nil {
|
||||||
|
log.Warnf("generate agent auth token: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthToken returns the current agent authentication token.
|
||||||
|
func (m *sessionManager) AuthToken() string {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.authToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop signals the session manager to exit its polling loop.
|
||||||
|
func (m *sessionManager) Stop() {
|
||||||
|
select {
|
||||||
|
case <-m.done:
|
||||||
|
default:
|
||||||
|
close(m.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sessionManager) run() {
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
sid := getActiveSessionID()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
if sid != m.sessionID {
|
||||||
|
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||||
|
m.killAgent()
|
||||||
|
m.sessionID = sid
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.agentProc != 0 {
|
||||||
|
var code uint32
|
||||||
|
_ = windows.GetExitCodeProcess(m.agentProc, &code)
|
||||||
|
if code != stillActive {
|
||||||
|
log.Infof("agent exited (code=%d), respawning", code)
|
||||||
|
windows.CloseHandle(m.agentProc)
|
||||||
|
m.agentProc = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.agentProc == 0 && sid != 0xFFFFFFFF {
|
||||||
|
m.authToken = generateAuthToken()
|
||||||
|
h, err := spawnAgentInSession(sid, m.port, m.authToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||||
|
m.authToken = ""
|
||||||
|
} else {
|
||||||
|
m.agentProc = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-m.done:
|
||||||
|
m.mu.Lock()
|
||||||
|
m.killAgent()
|
||||||
|
m.mu.Unlock()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sessionManager) killAgent() {
|
||||||
|
if m.agentProc != 0 {
|
||||||
|
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||||
|
windows.CloseHandle(m.agentProc)
|
||||||
|
m.agentProc = 0
|
||||||
|
log.Info("killed old agent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
|
||||||
|
// relogs them at the correct level with the service's formatter.
|
||||||
|
func relogAgentOutput(pipe windows.Handle) {
|
||||||
|
defer windows.CloseHandle(pipe)
|
||||||
|
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
entry := log.WithField("component", "vnc-agent")
|
||||||
|
dec := json.NewDecoder(f)
|
||||||
|
for dec.More() {
|
||||||
|
var m map[string]any
|
||||||
|
if err := dec.Decode(&m); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
msg, _ := m["msg"].(string)
|
||||||
|
if msg == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward extra fields from the agent (skip standard logrus fields).
|
||||||
|
// Remap "caller" to "source" so it doesn't conflict with logrus internals
|
||||||
|
// but still shows the original file/line from the agent process.
|
||||||
|
fields := make(log.Fields)
|
||||||
|
for k, v := range m {
|
||||||
|
switch k {
|
||||||
|
case "msg", "level", "time", "func":
|
||||||
|
continue
|
||||||
|
case "caller":
|
||||||
|
fields["source"] = v
|
||||||
|
default:
|
||||||
|
fields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e := entry.WithFields(fields)
|
||||||
|
|
||||||
|
switch m["level"] {
|
||||||
|
case "error":
|
||||||
|
e.Error(msg)
|
||||||
|
case "warning":
|
||||||
|
e.Warn(msg)
|
||||||
|
case "debug":
|
||||||
|
e.Debug(msg)
|
||||||
|
case "trace":
|
||||||
|
e.Trace(msg)
|
||||||
|
default:
|
||||||
|
e.Info(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyToAgent connects to the agent, sends the auth token, then proxies
|
||||||
|
// the VNC client connection bidirectionally.
|
||||||
|
func proxyToAgent(client net.Conn, port string, authToken string) {
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
addr := "127.0.0.1:" + port
|
||||||
|
var agentConn net.Conn
|
||||||
|
var err error
|
||||||
|
for range 50 {
|
||||||
|
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer agentConn.Close()
|
||||||
|
|
||||||
|
// Send the auth token so the agent can verify this connection
|
||||||
|
// comes from the trusted service process.
|
||||||
|
tokenBytes, _ := hex.DecodeString(authToken)
|
||||||
|
if _, err := agentConn.Write(tokenBytes); err != nil {
|
||||||
|
log.Warnf("send auth token to agent: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||||
|
|
||||||
|
done := make(chan struct{}, 2)
|
||||||
|
cp := func(label string, dst, src net.Conn) {
|
||||||
|
n, err := io.Copy(dst, src)
|
||||||
|
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||||
|
done <- struct{}{}
|
||||||
|
}
|
||||||
|
go cp("client→agent", agentConn, client)
|
||||||
|
go cp("agent→client", client, agentConn)
|
||||||
|
<-done
|
||||||
|
}
|
||||||
486
client/vnc/server/capture_darwin.go
Normal file
486
client/vnc/server/capture_darwin.go
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/maphash"
|
||||||
|
"image"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
var darwinCaptureOnce sync.Once
|
||||||
|
|
||||||
|
var (
|
||||||
|
cgMainDisplayID func() uint32
|
||||||
|
cgDisplayPixelsWide func(uint32) uintptr
|
||||||
|
cgDisplayPixelsHigh func(uint32) uintptr
|
||||||
|
cgDisplayCreateImage func(uint32) uintptr
|
||||||
|
cgImageGetWidth func(uintptr) uintptr
|
||||||
|
cgImageGetHeight func(uintptr) uintptr
|
||||||
|
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||||
|
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||||
|
cgImageGetDataProvider func(uintptr) uintptr
|
||||||
|
cgDataProviderCopyData func(uintptr) uintptr
|
||||||
|
cgImageRelease func(uintptr)
|
||||||
|
cfDataGetLength func(uintptr) int64
|
||||||
|
cfDataGetBytePtr func(uintptr) uintptr
|
||||||
|
cfRelease func(uintptr)
|
||||||
|
cgPreflightScreenCaptureAccess func() bool
|
||||||
|
cgRequestScreenCaptureAccess func() bool
|
||||||
|
darwinCaptureReady bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func initDarwinCapture() {
|
||||||
|
darwinCaptureOnce.Do(func() {
|
||||||
|
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreGraphics: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreFoundation: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||||
|
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||||
|
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||||
|
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||||
|
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||||
|
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||||
|
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||||
|
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||||
|
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||||
|
|
||||||
|
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
|
||||||
|
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
|
||||||
|
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
|
||||||
|
}
|
||||||
|
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||||
|
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||||
|
}
|
||||||
|
|
||||||
|
darwinCaptureReady = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// errFrameUnchanged signals that the raw capture bytes matched the previous
|
||||||
|
// frame, so the caller can skip the expensive BGRA to RGBA conversion.
|
||||||
|
var errFrameUnchanged = errors.New("frame unchanged")
|
||||||
|
|
||||||
|
// CGCapturer captures the macOS main display using Core Graphics.
|
||||||
|
type CGCapturer struct {
|
||||||
|
displayID uint32
|
||||||
|
w, h int
|
||||||
|
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
|
||||||
|
downscale int
|
||||||
|
hashSeed maphash.Seed
|
||||||
|
lastHash uint64
|
||||||
|
hasHash bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCGCapturer creates a screen capturer for the main display.
|
||||||
|
func NewCGCapturer() (*CGCapturer, error) {
|
||||||
|
initDarwinCapture()
|
||||||
|
if !darwinCaptureReady {
|
||||||
|
return nil, fmt.Errorf("CoreGraphics not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request Screen Recording permission (shows system dialog on macOS 11+).
|
||||||
|
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
|
||||||
|
if cgRequestScreenCaptureAccess != nil {
|
||||||
|
cgRequestScreenCaptureAccess()
|
||||||
|
}
|
||||||
|
openPrivacyPane("Privacy_ScreenCapture")
|
||||||
|
log.Warn("Screen Recording permission not granted. " +
|
||||||
|
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
|
||||||
|
}
|
||||||
|
|
||||||
|
displayID := cgMainDisplayID()
|
||||||
|
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
|
||||||
|
|
||||||
|
// Probe actual pixel dimensions via a test capture. CGDisplayPixelsWide/High
|
||||||
|
// returns logical points on Retina, but CGDisplayCreateImage produces native
|
||||||
|
// pixels (often 2x), so probing the image is the only reliable source.
|
||||||
|
img, err := c.Capture()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("probe capture: %w", err)
|
||||||
|
}
|
||||||
|
nativeW := img.Rect.Dx()
|
||||||
|
nativeH := img.Rect.Dy()
|
||||||
|
c.hasHash = false
|
||||||
|
if nativeW == 0 || nativeH == 0 {
|
||||||
|
return nil, errors.New("display dimensions are zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||||
|
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||||
|
|
||||||
|
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
|
||||||
|
// count 4x, shrinking convert, diff, and wire data proportionally.
|
||||||
|
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
|
||||||
|
c.downscale = 2
|
||||||
|
}
|
||||||
|
c.w = nativeW / c.downscale
|
||||||
|
c.h = nativeH / c.downscale
|
||||||
|
|
||||||
|
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
|
||||||
|
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func retinaDownscaleDisabled() bool {
|
||||||
|
v := os.Getenv(EnvVNCDisableDownscale)
|
||||||
|
if v == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
disabled, err := strconv.ParseBool(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (c *CGCapturer) Width() int { return c.w }
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (c *CGCapturer) Height() int { return c.h }
|
||||||
|
|
||||||
|
// Capture returns the current screen as an RGBA image.
|
||||||
|
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||||
|
cgImage := cgDisplayCreateImage(c.displayID)
|
||||||
|
if cgImage == 0 {
|
||||||
|
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||||
|
}
|
||||||
|
defer cgImageRelease(cgImage)
|
||||||
|
|
||||||
|
w := int(cgImageGetWidth(cgImage))
|
||||||
|
h := int(cgImageGetHeight(cgImage))
|
||||||
|
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||||
|
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||||
|
|
||||||
|
provider := cgImageGetDataProvider(cgImage)
|
||||||
|
if provider == 0 {
|
||||||
|
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfData := cgDataProviderCopyData(provider)
|
||||||
|
if cfData == 0 {
|
||||||
|
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||||
|
}
|
||||||
|
defer cfRelease(cfData)
|
||||||
|
|
||||||
|
dataLen := int(cfDataGetLength(cfData))
|
||||||
|
dataPtr := cfDataGetBytePtr(cfData)
|
||||||
|
if dataPtr == 0 || dataLen == 0 {
|
||||||
|
return nil, fmt.Errorf("empty image data")
|
||||||
|
}
|
||||||
|
|
||||||
|
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||||
|
|
||||||
|
hash := maphash.Bytes(c.hashSeed, src)
|
||||||
|
if c.hasHash && hash == c.lastHash {
|
||||||
|
return nil, errFrameUnchanged
|
||||||
|
}
|
||||||
|
c.lastHash = hash
|
||||||
|
c.hasHash = true
|
||||||
|
|
||||||
|
ds := c.downscale
|
||||||
|
if ds < 1 {
|
||||||
|
ds = 1
|
||||||
|
}
|
||||||
|
outW := w / ds
|
||||||
|
outH := h / ds
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||||
|
|
||||||
|
bytesPerPixel := bpp / 8
|
||||||
|
if bytesPerPixel == 4 && ds == 1 {
|
||||||
|
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
|
||||||
|
} else if bytesPerPixel == 4 && ds == 2 {
|
||||||
|
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
|
||||||
|
} else {
|
||||||
|
for row := 0; row < outH; row++ {
|
||||||
|
srcOff := row * ds * bytesPerRow
|
||||||
|
dstOff := row * img.Stride
|
||||||
|
for col := 0; col < outW; col++ {
|
||||||
|
si := srcOff + col*ds*bytesPerPixel
|
||||||
|
di := dstOff + col*4
|
||||||
|
img.Pix[di+0] = src[si+2]
|
||||||
|
img.Pix[di+1] = src[si+1]
|
||||||
|
img.Pix[di+2] = src[si+0]
|
||||||
|
img.Pix[di+3] = 0xff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
|
||||||
|
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
|
||||||
|
// destination dimensions (source is 2*outW by 2*outH).
|
||||||
|
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
|
||||||
|
workers := runtime.GOMAXPROCS(0)
|
||||||
|
if workers > outH {
|
||||||
|
workers = outH
|
||||||
|
}
|
||||||
|
if workers < 1 || outH < 32 {
|
||||||
|
workers = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
convertRows := func(y0, y1 int) {
|
||||||
|
for row := y0; row < y1; row++ {
|
||||||
|
srcRow0 := 2 * row * srcStride
|
||||||
|
srcRow1 := srcRow0 + srcStride
|
||||||
|
dstOff := row * dstStride
|
||||||
|
for col := 0; col < outW; col++ {
|
||||||
|
s0 := srcRow0 + col*8
|
||||||
|
s1 := srcRow1 + col*8
|
||||||
|
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
|
||||||
|
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
|
||||||
|
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
|
||||||
|
di := dstOff + col*4
|
||||||
|
dst[di+0] = byte(r)
|
||||||
|
dst[di+1] = byte(g)
|
||||||
|
dst[di+2] = byte(b)
|
||||||
|
dst[di+3] = 0xff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if workers == 1 {
|
||||||
|
convertRows(0, outH)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
chunk := (outH + workers - 1) / workers
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
y0 := i * chunk
|
||||||
|
y1 := y0 + chunk
|
||||||
|
if y1 > outH {
|
||||||
|
y1 = outH
|
||||||
|
}
|
||||||
|
if y0 >= y1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(y0, y1 int) {
|
||||||
|
defer wg.Done()
|
||||||
|
convertRows(y0, y1)
|
||||||
|
}(y0, y1)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
|
||||||
|
// parallelises across GOMAXPROCS cores for large images.
|
||||||
|
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||||
|
workers := runtime.GOMAXPROCS(0)
|
||||||
|
if workers > h {
|
||||||
|
workers = h
|
||||||
|
}
|
||||||
|
if workers < 1 || h < 64 {
|
||||||
|
workers = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
convertRows := func(y0, y1 int) {
|
||||||
|
rowBytes := w * 4
|
||||||
|
for row := y0; row < y1; row++ {
|
||||||
|
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
|
||||||
|
srcRow := src[row*srcStride : row*srcStride+rowBytes]
|
||||||
|
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
|
||||||
|
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
|
||||||
|
for i, p := range srcU {
|
||||||
|
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if workers == 1 {
|
||||||
|
convertRows(0, h)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
chunk := (h + workers - 1) / workers
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
y0 := i * chunk
|
||||||
|
y1 := y0 + chunk
|
||||||
|
if y1 > h {
|
||||||
|
y1 = h
|
||||||
|
}
|
||||||
|
if y0 >= y1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(y0, y1 int) {
|
||||||
|
defer wg.Done()
|
||||||
|
convertRows(y0, y1)
|
||||||
|
}(y0, y1)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MacPoller wraps CGCapturer in a continuous capture loop.
|
||||||
|
type MacPoller struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
frame *image.RGBA
|
||||||
|
w, h int
|
||||||
|
done chan struct{}
|
||||||
|
// wake shortens the init-retry backoff when a client is trying to connect,
|
||||||
|
// so granting Screen Recording mid-session takes effect immediately.
|
||||||
|
wake chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMacPoller creates a capturer that continuously grabs the macOS display.
|
||||||
|
func NewMacPoller() *MacPoller {
|
||||||
|
p := &MacPoller{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
wake: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
go p.loop()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wake pokes the init-retry loop so it doesn't wait out the full backoff
|
||||||
|
// before trying again. Safe to call from any goroutine; extra calls while a
|
||||||
|
// wake is pending are dropped.
|
||||||
|
func (p *MacPoller) Wake() {
|
||||||
|
select {
|
||||||
|
case p.wake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the capture loop.
|
||||||
|
func (p *MacPoller) Close() {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
default:
|
||||||
|
close(p.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (p *MacPoller) Width() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (p *MacPoller) Height() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture returns the most recent frame.
|
||||||
|
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
img := p.frame
|
||||||
|
p.mu.Unlock()
|
||||||
|
if img != nil {
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no frame available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MacPoller) loop() {
|
||||||
|
var capturer *CGCapturer
|
||||||
|
var initFails int
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturer == nil {
|
||||||
|
var err error
|
||||||
|
capturer, err = NewCGCapturer()
|
||||||
|
if err != nil {
|
||||||
|
initFails++
|
||||||
|
// Retry forever with backoff: the user may grant Screen
|
||||||
|
// Recording after the server started, and we need to pick it
|
||||||
|
// up whenever that happens.
|
||||||
|
delay := 2 * time.Second
|
||||||
|
if initFails > 15 {
|
||||||
|
delay = 30 * time.Second
|
||||||
|
} else if initFails > 5 {
|
||||||
|
delay = 10 * time.Second
|
||||||
|
}
|
||||||
|
if initFails == 1 || initFails%10 == 0 {
|
||||||
|
log.Warnf("macOS capturer: %v (attempt %d, retrying every %s)", err, initFails, delay)
|
||||||
|
} else {
|
||||||
|
log.Debugf("macOS capturer: %v (attempt %d)", err, initFails)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-p.wake:
|
||||||
|
// Client is trying to connect, retry now.
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
initFails = 0
|
||||||
|
p.mu.Lock()
|
||||||
|
p.w, p.h = capturer.Width(), capturer.Height()
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := capturer.Capture()
|
||||||
|
if errors.Is(err, errFrameUnchanged) {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(33 * time.Millisecond):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("macOS capture: %v", err)
|
||||||
|
capturer = nil
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.frame = img
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||||
99
client/vnc/server/capture_dxgi_windows.go
Normal file
99
client/vnc/server/capture_dxgi_windows.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
|
||||||
|
"github.com/kirides/go-d3d/d3d11"
|
||||||
|
"github.com/kirides/go-d3d/outputduplication"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||||
|
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||||
|
// Only works from the interactive user session, not Session 0.
|
||||||
|
//
|
||||||
|
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||||
|
// output buffer and hand it out. Alternating between two output buffers
|
||||||
|
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||||
|
type dxgiCapturer struct {
|
||||||
|
dup *outputduplication.OutputDuplicator
|
||||||
|
device *d3d11.ID3D11Device
|
||||||
|
ctx *d3d11.ID3D11DeviceContext
|
||||||
|
img *image.RGBA
|
||||||
|
out [2]*image.RGBA
|
||||||
|
outIdx int
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||||
|
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||||
|
if err != nil {
|
||||||
|
device.Release()
|
||||||
|
deviceCtx.Release()
|
||||||
|
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w, h := screenSize()
|
||||||
|
if w == 0 || h == 0 {
|
||||||
|
dup.Release()
|
||||||
|
device.Release()
|
||||||
|
deviceCtx.Release()
|
||||||
|
return nil, fmt.Errorf("screen dimensions are zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
rect := image.Rect(0, 0, w, h)
|
||||||
|
c := &dxgiCapturer{
|
||||||
|
dup: dup,
|
||||||
|
device: device,
|
||||||
|
ctx: deviceCtx,
|
||||||
|
img: image.NewRGBA(rect),
|
||||||
|
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||||
|
width: w,
|
||||||
|
height: h,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grab the initial frame with a longer timeout to ensure we have
|
||||||
|
// a valid image before returning.
|
||||||
|
_ = dup.GetImage(c.img, 2000)
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||||
|
err := c.dup.GetImage(c.img, 100)
|
||||||
|
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||||
|
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||||
|
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||||
|
out := c.out[c.outIdx]
|
||||||
|
c.outIdx ^= 1
|
||||||
|
copy(out.Pix, c.img.Pix)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dxgiCapturer) close() {
|
||||||
|
if c.dup != nil {
|
||||||
|
c.dup.Release()
|
||||||
|
c.dup = nil
|
||||||
|
}
|
||||||
|
if c.ctx != nil {
|
||||||
|
c.ctx.Release()
|
||||||
|
c.ctx = nil
|
||||||
|
}
|
||||||
|
if c.device != nil {
|
||||||
|
c.device.Release()
|
||||||
|
c.device = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
461
client/vnc/server/capture_windows.go
Normal file
461
client/vnc/server/capture_windows.go
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||||
|
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||||
|
|
||||||
|
procGetDC = user32.NewProc("GetDC")
|
||||||
|
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||||
|
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||||
|
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||||
|
procSelectObject = gdi32.NewProc("SelectObject")
|
||||||
|
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||||
|
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||||
|
procBitBlt = gdi32.NewProc("BitBlt")
|
||||||
|
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||||
|
|
||||||
|
// Desktop switching for service/Session 0 capture.
|
||||||
|
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||||
|
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||||
|
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||||
|
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||||
|
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||||
|
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||||
|
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||||
|
)
|
||||||
|
|
||||||
|
const uoiName = 2
|
||||||
|
|
||||||
|
const (
|
||||||
|
smCxScreen = 0
|
||||||
|
smCyScreen = 1
|
||||||
|
srccopy = 0x00CC0020
|
||||||
|
dibRgbColors = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type bitmapInfoHeader struct {
|
||||||
|
Size uint32
|
||||||
|
Width int32
|
||||||
|
Height int32
|
||||||
|
Planes uint16
|
||||||
|
BitCount uint16
|
||||||
|
Compression uint32
|
||||||
|
SizeImage uint32
|
||||||
|
XPelsPerMeter int32
|
||||||
|
YPelsPerMeter int32
|
||||||
|
ClrUsed uint32
|
||||||
|
ClrImportant uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type bitmapInfo struct {
|
||||||
|
Header bitmapInfoHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||||
|
// the interactive window station. This is required for a SYSTEM service in
|
||||||
|
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||||
|
func setupInteractiveWindowStation() error {
|
||||||
|
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||||
|
}
|
||||||
|
hWinSta, _, err := procOpenWindowStation.Call(
|
||||||
|
uintptr(unsafe.Pointer(name)),
|
||||||
|
0,
|
||||||
|
uintptr(windows.MAXIMUM_ALLOWED),
|
||||||
|
)
|
||||||
|
if hWinSta == 0 {
|
||||||
|
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||||
|
}
|
||||||
|
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||||
|
if r == 0 {
|
||||||
|
procCloseWindowStation.Call(hWinSta)
|
||||||
|
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||||
|
}
|
||||||
|
log.Info("process window station set to WinSta0 (interactive)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func screenSize() (int, int) {
|
||||||
|
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||||
|
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||||
|
return int(w), int(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDesktopName(hDesk uintptr) string {
|
||||||
|
var buf [256]uint16
|
||||||
|
var needed uint32
|
||||||
|
procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||||
|
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||||
|
uintptr(unsafe.Pointer(&needed)))
|
||||||
|
return windows.UTF16ToString(buf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// switchToInputDesktop opens the desktop currently receiving user input
|
||||||
|
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||||
|
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||||
|
func switchToInputDesktop() (bool, string) {
|
||||||
|
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||||
|
if hDesk == 0 {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
name := getDesktopName(hDesk)
|
||||||
|
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||||
|
procCloseDesktop.Call(hDesk)
|
||||||
|
return ret != 0, name
|
||||||
|
}
|
||||||
|
|
||||||
|
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||||
|
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||||
|
type gdiCapturer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
|
||||||
|
// Pre-allocated GDI resources, reused across captures.
|
||||||
|
memDC uintptr
|
||||||
|
bmp uintptr
|
||||||
|
bits uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGDICapturer() (*gdiCapturer, error) {
|
||||||
|
w, h := screenSize()
|
||||||
|
if w == 0 || h == 0 {
|
||||||
|
return nil, fmt.Errorf("screen dimensions are zero")
|
||||||
|
}
|
||||||
|
c := &gdiCapturer{width: w, height: h}
|
||||||
|
if err := c.allocGDI(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||||
|
func (c *gdiCapturer) allocGDI() error {
|
||||||
|
screenDC, _, _ := procGetDC.Call(0)
|
||||||
|
if screenDC == 0 {
|
||||||
|
return fmt.Errorf("GetDC returned 0")
|
||||||
|
}
|
||||||
|
defer procReleaseDC.Call(0, screenDC)
|
||||||
|
|
||||||
|
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||||
|
if memDC == 0 {
|
||||||
|
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
bi := bitmapInfo{
|
||||||
|
Header: bitmapInfoHeader{
|
||||||
|
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||||
|
Width: int32(c.width),
|
||||||
|
Height: -int32(c.height), // negative = top-down DIB
|
||||||
|
Planes: 1,
|
||||||
|
BitCount: 32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var bits uintptr
|
||||||
|
bmp, _, _ := procCreateDIBSection.Call(
|
||||||
|
screenDC,
|
||||||
|
uintptr(unsafe.Pointer(&bi)),
|
||||||
|
dibRgbColors,
|
||||||
|
uintptr(unsafe.Pointer(&bits)),
|
||||||
|
0, 0,
|
||||||
|
)
|
||||||
|
if bmp == 0 || bits == 0 {
|
||||||
|
procDeleteDC.Call(memDC)
|
||||||
|
return fmt.Errorf("CreateDIBSection returned 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
procSelectObject.Call(memDC, bmp)
|
||||||
|
|
||||||
|
c.memDC = memDC
|
||||||
|
c.bmp = bmp
|
||||||
|
c.bits = bits
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||||
|
|
||||||
|
// freeGDI releases pre-allocated GDI resources.
|
||||||
|
func (c *gdiCapturer) freeGDI() {
|
||||||
|
if c.bmp != 0 {
|
||||||
|
procDeleteObject.Call(c.bmp)
|
||||||
|
c.bmp = 0
|
||||||
|
}
|
||||||
|
if c.memDC != 0 {
|
||||||
|
procDeleteDC.Call(c.memDC)
|
||||||
|
c.memDC = 0
|
||||||
|
}
|
||||||
|
c.bits = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.memDC == 0 {
|
||||||
|
return nil, fmt.Errorf("GDI resources not allocated")
|
||||||
|
}
|
||||||
|
|
||||||
|
screenDC, _, _ := procGetDC.Call(0)
|
||||||
|
if screenDC == 0 {
|
||||||
|
return nil, fmt.Errorf("GetDC returned 0")
|
||||||
|
}
|
||||||
|
defer procReleaseDC.Call(0, screenDC)
|
||||||
|
|
||||||
|
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||||
|
screenDC, 0, 0, srccopy)
|
||||||
|
if ret == 0 {
|
||||||
|
return nil, fmt.Errorf("BitBlt returned 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
n := c.width * c.height * 4
|
||||||
|
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||||
|
|
||||||
|
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||||
|
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||||
|
// per pixel instead of three separate byte assignments).
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||||
|
pix := img.Pix
|
||||||
|
copy(pix, raw)
|
||||||
|
swizzleBGRAtoRGBA(pix)
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||||
|
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||||
|
// captures frames, which are retrieved by the VNC session on demand.
|
||||||
|
// Capture pauses automatically when no clients are connected.
|
||||||
|
type DesktopCapturer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
frame *image.RGBA
|
||||||
|
w, h int
|
||||||
|
|
||||||
|
// clients tracks the number of active VNC sessions. When zero, the
|
||||||
|
// capture loop idles instead of grabbing frames.
|
||||||
|
clients atomic.Int32
|
||||||
|
|
||||||
|
// wake is signaled when a client connects and the loop should resume.
|
||||||
|
wake chan struct{}
|
||||||
|
// done is closed when Close is called, terminating the capture loop.
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
|
||||||
|
func NewDesktopCapturer() *DesktopCapturer {
|
||||||
|
c := &DesktopCapturer{
|
||||||
|
wake: make(chan struct{}, 1),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go c.loop()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConnect increments the active client count, resuming capture if needed.
|
||||||
|
func (c *DesktopCapturer) ClientConnect() {
|
||||||
|
c.clients.Add(1)
|
||||||
|
select {
|
||||||
|
case c.wake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientDisconnect decrements the active client count.
|
||||||
|
func (c *DesktopCapturer) ClientDisconnect() {
|
||||||
|
c.clients.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the capture loop and releases resources.
|
||||||
|
func (c *DesktopCapturer) Close() {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
default:
|
||||||
|
close(c.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the current screen width.
|
||||||
|
func (c *DesktopCapturer) Width() int {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Height returns the current screen height.
|
||||||
|
func (c *DesktopCapturer) Height() int {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture returns the most recent desktop frame.
|
||||||
|
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
img := c.frame
|
||||||
|
c.mu.Unlock()
|
||||||
|
if img != nil {
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no frame available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForClient blocks until a client connects or the capturer is closed.
|
||||||
|
func (c *DesktopCapturer) waitForClient() bool {
|
||||||
|
if c.clients.Load() > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-c.wake:
|
||||||
|
return true
|
||||||
|
case <-c.done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DesktopCapturer) loop() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
// When running as a Windows service (Session 0), we need to attach to the
|
||||||
|
// interactive window station before OpenInputDesktop will succeed.
|
||||||
|
if err := setupInteractiveWindowStation(); err != nil {
|
||||||
|
log.Warnf("attach to interactive window station: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
|
||||||
|
defer frameTicker.Stop()
|
||||||
|
|
||||||
|
retryTimer := time.NewTimer(0)
|
||||||
|
retryTimer.Stop()
|
||||||
|
defer retryTimer.Stop()
|
||||||
|
|
||||||
|
type frameCapturer interface {
|
||||||
|
capture() (*image.RGBA, error)
|
||||||
|
close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var cap frameCapturer
|
||||||
|
var desktopFails int
|
||||||
|
var lastDesktop string
|
||||||
|
|
||||||
|
createCapturer := func() (frameCapturer, error) {
|
||||||
|
dc, err := newDXGICapturer()
|
||||||
|
if err == nil {
|
||||||
|
log.Info("using DXGI Desktop Duplication for capture")
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
|
||||||
|
gc, err := newGDICapturer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Info("using GDI BitBlt for capture")
|
||||||
|
return gc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if !c.waitForClient() {
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No clients: release the capturer and wait.
|
||||||
|
if c.clients.Load() <= 0 {
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
cap = nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, desk := switchToInputDesktop()
|
||||||
|
if !ok {
|
||||||
|
desktopFails++
|
||||||
|
if desktopFails == 1 || desktopFails%100 == 0 {
|
||||||
|
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
|
||||||
|
}
|
||||||
|
retryTimer.Reset(100 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-retryTimer.C:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if desktopFails > 0 {
|
||||||
|
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
|
||||||
|
desktopFails = 0
|
||||||
|
}
|
||||||
|
if desk != lastDesktop {
|
||||||
|
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
|
||||||
|
lastDesktop = desk
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
}
|
||||||
|
cap = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cap == nil {
|
||||||
|
fc, err := createCapturer()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("create capturer: %v", err)
|
||||||
|
retryTimer.Reset(500 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-retryTimer.C:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cap = fc
|
||||||
|
w, h := screenSize()
|
||||||
|
c.mu.Lock()
|
||||||
|
c.w, c.h = w, h
|
||||||
|
c.mu.Unlock()
|
||||||
|
log.Infof("screen capturer ready: %dx%d", w, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := cap.capture()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("capture: %v", err)
|
||||||
|
cap.close()
|
||||||
|
cap = nil
|
||||||
|
retryTimer.Reset(100 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-retryTimer.C:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.frame = img
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-frameTicker.C:
|
||||||
|
case <-c.done:
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
385
client/vnc/server/capture_x11.go
Normal file
385
client/vnc/server/capture_x11.go
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/jezek/xgb"
|
||||||
|
"github.com/jezek/xgb/xproto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||||
|
type X11Capturer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
conn *xgb.Conn
|
||||||
|
screen *xproto.ScreenInfo
|
||||||
|
w, h int
|
||||||
|
shmID int
|
||||||
|
shmAddr []byte
|
||||||
|
shmSeg uint32 // shm.Seg
|
||||||
|
useSHM bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||||
|
// environment variables if needed. This is required when running as a system
|
||||||
|
// service where these vars aren't set.
|
||||||
|
func detectX11Display() {
|
||||||
|
if os.Getenv("DISPLAY") != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||||
|
if detectX11FromProc() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if detectX11FromSockets() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||||
|
func detectX11FromProc() bool {
|
||||||
|
entries, err := os.ReadDir("/proc")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||||
|
setDisplayEnv(display, auth)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||||
|
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||||
|
func detectX11FromSockets() bool {
|
||||||
|
entries, err := os.ReadDir("/tmp/.X11-unix")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the lowest display number.
|
||||||
|
for _, e := range entries {
|
||||||
|
name := e.Name()
|
||||||
|
if len(name) < 2 || name[0] != 'X' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
display := ":" + name[1:]
|
||||||
|
os.Setenv("DISPLAY", display)
|
||||||
|
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||||
|
|
||||||
|
// Try to find -auth from ps output.
|
||||||
|
if auth := findXorgAuthFromPS(); auth != "" {
|
||||||
|
os.Setenv("XAUTHORITY", auth)
|
||||||
|
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||||
|
func findXorgAuthFromPS() string {
|
||||||
|
out, err := exec.Command("ps", "auxww").Output()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
for i, f := range fields {
|
||||||
|
if f == "-auth" && i+1 < len(fields) {
|
||||||
|
return fields[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseXorgArgs(args []string) (display, auth string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
base := args[0]
|
||||||
|
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||||
|
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
for i, arg := range args[1:] {
|
||||||
|
if len(arg) > 0 && arg[0] == ':' {
|
||||||
|
display = arg
|
||||||
|
}
|
||||||
|
if arg == "-auth" && i+2 < len(args) {
|
||||||
|
auth = args[i+2]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return display, auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDisplayEnv(display, auth string) {
|
||||||
|
os.Setenv("DISPLAY", display)
|
||||||
|
log.Infof("auto-detected DISPLAY=%s", display)
|
||||||
|
if auth != "" {
|
||||||
|
os.Setenv("XAUTHORITY", auth)
|
||||||
|
log.Infof("auto-detected XAUTHORITY=%s", auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitCmdline(data []byte) []string {
|
||||||
|
var args []string
|
||||||
|
for _, b := range splitNull(data) {
|
||||||
|
if len(b) > 0 {
|
||||||
|
args = append(args, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitNull(data []byte) [][]byte {
|
||||||
|
var parts [][]byte
|
||||||
|
start := 0
|
||||||
|
for i, b := range data {
|
||||||
|
if b == 0 {
|
||||||
|
parts = append(parts, data[start:i])
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start < len(data) {
|
||||||
|
parts = append(parts, data[start:])
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||||
|
func NewX11Capturer(display string) (*X11Capturer, error) {
|
||||||
|
detectX11Display()
|
||||||
|
|
||||||
|
if display == "" {
|
||||||
|
display = os.Getenv("DISPLAY")
|
||||||
|
}
|
||||||
|
if display == "" {
|
||||||
|
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := xgb.NewConnDisplay(display)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setup := xproto.Setup(conn)
|
||||||
|
if len(setup.Roots) == 0 {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("no X11 screens")
|
||||||
|
}
|
||||||
|
screen := setup.Roots[0]
|
||||||
|
|
||||||
|
c := &X11Capturer{
|
||||||
|
conn: conn,
|
||||||
|
screen: &screen,
|
||||||
|
w: int(screen.WidthInPixels),
|
||||||
|
h: int(screen.HeightInPixels),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.initSHM(); err != nil {
|
||||||
|
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||||
|
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||||
|
// the capturer falls back to GetImage.
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (c *X11Capturer) Width() int { return c.w }
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (c *X11Capturer) Height() int { return c.h }
|
||||||
|
|
||||||
|
// Capture returns the current screen as an RGBA image.
|
||||||
|
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.useSHM {
|
||||||
|
return c.captureSHM()
|
||||||
|
}
|
||||||
|
return c.captureGetImage()
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||||
|
|
||||||
|
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||||
|
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||||
|
xproto.Drawable(c.screen.Root),
|
||||||
|
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||||
|
|
||||||
|
reply, err := cookie.Reply()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetImage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||||
|
data := reply.Data
|
||||||
|
n := c.w * c.h * 4
|
||||||
|
if len(data) < n {
|
||||||
|
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 4 {
|
||||||
|
img.Pix[i+0] = data[i+2] // R
|
||||||
|
img.Pix[i+1] = data[i+1] // G
|
||||||
|
img.Pix[i+2] = data[i+0] // B
|
||||||
|
img.Pix[i+3] = 0xff
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases X11 resources.
|
||||||
|
func (c *X11Capturer) Close() {
|
||||||
|
c.closeSHM()
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||||
|
|
||||||
|
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
|
||||||
|
// DesktopCapturer pattern from Windows.
|
||||||
|
type X11Poller struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
frame *image.RGBA
|
||||||
|
w, h int
|
||||||
|
display string
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewX11Poller creates a capturer that continuously grabs the X11 display.
|
||||||
|
func NewX11Poller(display string) *X11Poller {
|
||||||
|
p := &X11Poller{
|
||||||
|
display: display,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go p.loop()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the capture loop.
|
||||||
|
func (p *X11Poller) Close() {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
default:
|
||||||
|
close(p.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (p *X11Poller) Width() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (p *X11Poller) Height() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture returns the most recent frame.
|
||||||
|
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
img := p.frame
|
||||||
|
p.mu.Unlock()
|
||||||
|
if img != nil {
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no frame available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *X11Poller) loop() {
|
||||||
|
var capturer *X11Capturer
|
||||||
|
var initFails int
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if capturer != nil {
|
||||||
|
capturer.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturer == nil {
|
||||||
|
var err error
|
||||||
|
capturer, err = NewX11Capturer(p.display)
|
||||||
|
if err != nil {
|
||||||
|
initFails++
|
||||||
|
if initFails <= maxCapturerRetries {
|
||||||
|
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
initFails = 0
|
||||||
|
p.mu.Lock()
|
||||||
|
p.w, p.h = capturer.Width(), capturer.Height()
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := capturer.Capture()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("X11 capture: %v", err)
|
||||||
|
capturer.Close()
|
||||||
|
capturer = nil
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.frame = img
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
78
client/vnc/server/capture_x11_shm_linux.go
Normal file
78
client/vnc/server/capture_x11_shm_linux.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
|
||||||
|
"github.com/jezek/xgb/shm"
|
||||||
|
"github.com/jezek/xgb/xproto"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *X11Capturer) initSHM() error {
|
||||||
|
if err := shm.Init(c.conn); err != nil {
|
||||||
|
return fmt.Errorf("init SHM extension: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := c.w * c.h * 4
|
||||||
|
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("shmget: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||||
|
return fmt.Errorf("shmat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||||
|
|
||||||
|
seg, err := shm.NewSegId(c.conn)
|
||||||
|
if err != nil {
|
||||||
|
unix.SysvShmDetach(addr)
|
||||||
|
return fmt.Errorf("new SHM seg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||||
|
unix.SysvShmDetach(addr)
|
||||||
|
return fmt.Errorf("SHM attach to X: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.shmID = id
|
||||||
|
c.shmAddr = addr
|
||||||
|
c.shmSeg = uint32(seg)
|
||||||
|
c.useSHM = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||||
|
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||||
|
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||||
|
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||||
|
|
||||||
|
_, err := cookie.Reply()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SHM GetImage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||||
|
n := c.w * c.h * 4
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 4 {
|
||||||
|
img.Pix[i+0] = c.shmAddr[i+2] // R
|
||||||
|
img.Pix[i+1] = c.shmAddr[i+1] // G
|
||||||
|
img.Pix[i+2] = c.shmAddr[i+0] // B
|
||||||
|
img.Pix[i+3] = 0xff
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) closeSHM() {
|
||||||
|
if c.useSHM {
|
||||||
|
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||||
|
unix.SysvShmDetach(c.shmAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
18
client/vnc/server/capture_x11_shm_stub.go
Normal file
18
client/vnc/server/capture_x11_shm_stub.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *X11Capturer) initSHM() error {
|
||||||
|
return fmt.Errorf("SysV SHM not available on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||||
|
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) closeSHM() {}
|
||||||
151
client/vnc/server/crypto.go
Normal file
151
client/vnc/server/crypto.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"crypto/sha256"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
aesKeySize = 32 // AES-256
|
||||||
|
gcmNonceSize = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
// recCrypto holds per-session encryption state.
|
||||||
|
type recCrypto struct {
|
||||||
|
gcm cipher.AEAD
|
||||||
|
frameCounter uint64
|
||||||
|
// ephemeralPub is stored in the recording header so the admin can derive the same key.
|
||||||
|
ephemeralPub []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRecCrypto sets up encryption for a new recording session.
|
||||||
|
// adminPubKeyB64 is the base64-encoded X25519 public key from management settings.
|
||||||
|
func newRecCrypto(adminPubKeyB64 string) (*recCrypto, error) {
|
||||||
|
adminPubBytes, err := base64.StdEncoding.DecodeString(adminPubKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode admin public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adminPub, err := ecdh.X25519().NewPublicKey(adminPubBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse admin X25519 public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate ephemeral keypair
|
||||||
|
ephemeral, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ephemeral key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ECDH shared secret
|
||||||
|
shared, err := ephemeral.ECDH(adminPub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ECDH: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive AES-256 key via HKDF
|
||||||
|
aesKey, err := deriveKey(shared, ephemeral.PublicKey().Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("derive key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(aesKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &recCrypto{
|
||||||
|
gcm: gcm,
|
||||||
|
ephemeralPub: ephemeral.PublicKey().Bytes(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encrypt encrypts plaintext using a counter-based nonce. Each call increments the counter.
|
||||||
|
func (c *recCrypto) encrypt(plaintext []byte) []byte {
|
||||||
|
nonce := make([]byte, gcmNonceSize)
|
||||||
|
binary.LittleEndian.PutUint64(nonce, c.frameCounter)
|
||||||
|
c.frameCounter++
|
||||||
|
return c.gcm.Seal(nil, nonce, plaintext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptRecording creates a decryptor from the admin's private key and the ephemeral public key from the header.
|
||||||
|
func DecryptRecording(adminPrivKeyB64 string, ephemeralPubB64 string) (*recDecryptor, error) {
|
||||||
|
adminPrivBytes, err := base64.StdEncoding.DecodeString(adminPrivKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode admin private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adminPriv, err := ecdh.X25519().NewPrivateKey(adminPrivBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse admin X25519 private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ephPubBytes, err := base64.StdEncoding.DecodeString(ephemeralPubB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode ephemeral public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ephPub, err := ecdh.X25519().NewPublicKey(ephPubBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse ephemeral public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shared, err := adminPriv.ECDH(ephPub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ECDH: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
aesKey, err := deriveKey(shared, ephPubBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("derive key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(aesKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &recDecryptor{gcm: gcm}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type recDecryptor struct {
|
||||||
|
gcm cipher.AEAD
|
||||||
|
frameCounter uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts a frame. Must be called in the same order as encryption.
|
||||||
|
func (d *recDecryptor) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||||
|
nonce := make([]byte, gcmNonceSize)
|
||||||
|
binary.LittleEndian.PutUint64(nonce, d.frameCounter)
|
||||||
|
d.frameCounter++
|
||||||
|
return d.gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveKey(shared, ephemeralPub []byte) ([]byte, error) {
|
||||||
|
hkdfReader := hkdf.New(sha256.New, shared, ephemeralPub, []byte("netbird-recording"))
|
||||||
|
key := make([]byte, aesKeySize)
|
||||||
|
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
129
client/vnc/server/crypto_test.go
Normal file
129
client/vnc/server/crypto_test.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCryptoRoundtrip(t *testing.T) {
|
||||||
|
// Generate admin keypair
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||||
|
|
||||||
|
// Create encryptor (recording side)
|
||||||
|
enc, err := newRecCrypto(adminPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, enc.ephemeralPub, 32)
|
||||||
|
|
||||||
|
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||||
|
|
||||||
|
// Encrypt some frames
|
||||||
|
plaintext1 := []byte("frame data one - PNG bytes would go here")
|
||||||
|
plaintext2 := []byte("frame data two - different content")
|
||||||
|
plaintext3 := make([]byte, 1024*100) // 100KB frame
|
||||||
|
rand.Read(plaintext3)
|
||||||
|
|
||||||
|
ct1 := enc.encrypt(plaintext1)
|
||||||
|
ct2 := enc.encrypt(plaintext2)
|
||||||
|
ct3 := enc.encrypt(plaintext3)
|
||||||
|
|
||||||
|
// Ciphertext should differ from plaintext
|
||||||
|
assert.NotEqual(t, plaintext1, ct1)
|
||||||
|
// Ciphertext is larger (GCM tag overhead)
|
||||||
|
assert.Greater(t, len(ct1), len(plaintext1))
|
||||||
|
|
||||||
|
// Create decryptor (playback side)
|
||||||
|
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decrypt in same order
|
||||||
|
got1, err := dec.Decrypt(ct1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext1, got1)
|
||||||
|
|
||||||
|
got2, err := dec.Decrypt(ct2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext2, got2)
|
||||||
|
|
||||||
|
got3, err := dec.Decrypt(ct3)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext3, got3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoWrongKey(t *testing.T) {
|
||||||
|
// Admin key
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
|
||||||
|
// Encrypt with admin's public key
|
||||||
|
enc, err := newRecCrypto(adminPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||||
|
|
||||||
|
ct := enc.encrypt([]byte("secret frame data"))
|
||||||
|
|
||||||
|
// Try to decrypt with a different private key
|
||||||
|
wrongPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
wrongPrivB64 := base64.StdEncoding.EncodeToString(wrongPriv.Bytes())
|
||||||
|
|
||||||
|
dec, err := DecryptRecording(wrongPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = dec.Decrypt(ct)
|
||||||
|
assert.Error(t, err, "decryption with wrong key should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoInvalidKey(t *testing.T) {
|
||||||
|
_, err := newRecCrypto("")
|
||||||
|
assert.Error(t, err, "empty key should fail")
|
||||||
|
|
||||||
|
_, err = newRecCrypto("not-base64!!!")
|
||||||
|
assert.Error(t, err, "invalid base64 should fail")
|
||||||
|
|
||||||
|
_, err = newRecCrypto(base64.StdEncoding.EncodeToString([]byte("too-short")))
|
||||||
|
assert.Error(t, err, "wrong-length key should fail")
|
||||||
|
|
||||||
|
_, err = DecryptRecording("", "validbutirrelevant")
|
||||||
|
assert.Error(t, err, "empty private key should fail")
|
||||||
|
|
||||||
|
_, err = DecryptRecording("not-base64!!!", base64.StdEncoding.EncodeToString(make([]byte, 32)))
|
||||||
|
assert.Error(t, err, "invalid base64 private key should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoOutOfOrderFails(t *testing.T) {
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||||
|
|
||||||
|
enc, err := newRecCrypto(adminPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||||
|
|
||||||
|
ct0 := enc.encrypt([]byte("frame 0"))
|
||||||
|
ct1 := enc.encrypt([]byte("frame 1"))
|
||||||
|
|
||||||
|
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Skip frame 0, try to decrypt frame 1 first (wrong nonce)
|
||||||
|
_, err = dec.Decrypt(ct1)
|
||||||
|
assert.Error(t, err, "out-of-order decryption should fail due to nonce mismatch")
|
||||||
|
|
||||||
|
// But frame 0 with a fresh decryptor should work
|
||||||
|
dec2, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
got, err := dec2.Decrypt(ct0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("frame 0"), got)
|
||||||
|
}
|
||||||
540
client/vnc/server/input_darwin.go
Normal file
540
client/vnc/server/input_darwin.go
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Core Graphics event constants.
|
||||||
|
const (
|
||||||
|
kCGEventSourceStateCombinedSessionState int32 = 0
|
||||||
|
|
||||||
|
kCGEventLeftMouseDown int32 = 1
|
||||||
|
kCGEventLeftMouseUp int32 = 2
|
||||||
|
kCGEventRightMouseDown int32 = 3
|
||||||
|
kCGEventRightMouseUp int32 = 4
|
||||||
|
kCGEventMouseMoved int32 = 5
|
||||||
|
kCGEventLeftMouseDragged int32 = 6
|
||||||
|
kCGEventRightMouseDragged int32 = 7
|
||||||
|
kCGEventKeyDown int32 = 10
|
||||||
|
kCGEventKeyUp int32 = 11
|
||||||
|
kCGEventOtherMouseDown int32 = 25
|
||||||
|
kCGEventOtherMouseUp int32 = 26
|
||||||
|
|
||||||
|
kCGMouseButtonLeft int32 = 0
|
||||||
|
kCGMouseButtonRight int32 = 1
|
||||||
|
kCGMouseButtonCenter int32 = 2
|
||||||
|
|
||||||
|
kCGHIDEventTap int32 = 0
|
||||||
|
|
||||||
|
// IOKit power management constants.
|
||||||
|
kIOPMUserActiveLocal int32 = 0
|
||||||
|
kIOPMAssertionLevelOn uint32 = 255
|
||||||
|
kCFStringEncodingUTF8 uint32 = 0x08000100
|
||||||
|
)
|
||||||
|
|
||||||
|
var darwinInputOnce sync.Once
|
||||||
|
|
||||||
|
var (
|
||||||
|
cgEventSourceCreate func(int32) uintptr
|
||||||
|
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
|
||||||
|
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
|
||||||
|
// purego can't handle array/struct types but individual float64s work.
|
||||||
|
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
|
||||||
|
cgEventPost func(int32, uintptr)
|
||||||
|
|
||||||
|
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
|
||||||
|
cgEventCreateScrollWheelEventAddr uintptr
|
||||||
|
|
||||||
|
axIsProcessTrusted func() bool
|
||||||
|
|
||||||
|
// IOKit power-management bindings used to wake the display and inhibit
|
||||||
|
// idle sleep while a VNC client is driving input.
|
||||||
|
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
|
||||||
|
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
|
||||||
|
iopmAssertionRelease func(uint32) int32
|
||||||
|
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
|
||||||
|
|
||||||
|
// Cached CFStrings for assertion name and idle-sleep type.
|
||||||
|
pmAssertionNameCFStr uintptr
|
||||||
|
pmPreventIdleDisplayCFStr uintptr
|
||||||
|
|
||||||
|
// Assertion IDs. userActivityID is reused across input events so repeated
|
||||||
|
// calls refresh the same assertion rather than create new ones.
|
||||||
|
pmMu sync.Mutex
|
||||||
|
userActivityID uint32
|
||||||
|
preventSleepID uint32
|
||||||
|
preventSleepHeld bool
|
||||||
|
|
||||||
|
darwinInputReady bool
|
||||||
|
darwinEventSource uintptr
|
||||||
|
)
|
||||||
|
|
||||||
|
func initDarwinInput() {
|
||||||
|
darwinInputOnce.Do(func() {
|
||||||
|
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreGraphics for input: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
|
||||||
|
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
|
||||||
|
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
|
||||||
|
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
|
||||||
|
|
||||||
|
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
|
||||||
|
if err == nil {
|
||||||
|
cgEventCreateScrollWheelEventAddr = sym
|
||||||
|
}
|
||||||
|
|
||||||
|
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
|
||||||
|
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
|
||||||
|
purego.RegisterFunc(&axIsProcessTrusted, sym)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
initPowerAssertions()
|
||||||
|
|
||||||
|
darwinInputReady = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func initPowerAssertions() {
|
||||||
|
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load IOKit: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreFoundation for power assertions: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
|
||||||
|
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
|
||||||
|
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
|
||||||
|
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
|
||||||
|
|
||||||
|
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
|
||||||
|
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wakeDisplay declares user activity so macOS treats the synthesized input as
|
||||||
|
// real HID activity, waking the display if it is asleep. Called on every key
|
||||||
|
// and pointer event; the kernel coalesces repeated calls cheaply.
|
||||||
|
func wakeDisplay() {
|
||||||
|
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
id := userActivityID
|
||||||
|
pmMu.Unlock()
|
||||||
|
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
|
||||||
|
if r != 0 {
|
||||||
|
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
userActivityID = id
|
||||||
|
pmMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// holdPreventIdleSleep creates an assertion that keeps the display from going
|
||||||
|
// idle-to-sleep while a VNC session is active. Safe to call repeatedly.
|
||||||
|
func holdPreventIdleSleep() {
|
||||||
|
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
defer pmMu.Unlock()
|
||||||
|
if preventSleepHeld {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var id uint32
|
||||||
|
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
|
||||||
|
if r != 0 {
|
||||||
|
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preventSleepID = id
|
||||||
|
preventSleepHeld = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// releasePreventIdleSleep drops the idle-sleep assertion.
|
||||||
|
func releasePreventIdleSleep() {
|
||||||
|
if iopmAssertionRelease == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
defer pmMu.Unlock()
|
||||||
|
if !preventSleepHeld {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r := iopmAssertionRelease(preventSleepID); r != 0 {
|
||||||
|
log.Debugf("IOPMAssertionRelease returned %d", r)
|
||||||
|
}
|
||||||
|
preventSleepHeld = false
|
||||||
|
preventSleepID = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureEventSource() uintptr {
|
||||||
|
if darwinEventSource != 0 {
|
||||||
|
return darwinEventSource
|
||||||
|
}
|
||||||
|
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
|
||||||
|
return darwinEventSource
|
||||||
|
}
|
||||||
|
|
||||||
|
// MacInputInjector injects keyboard and mouse events via Core Graphics.
|
||||||
|
type MacInputInjector struct {
|
||||||
|
lastButtons uint8
|
||||||
|
pbcopyPath string
|
||||||
|
pbpastePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMacInputInjector creates a macOS input injector.
|
||||||
|
func NewMacInputInjector() (*MacInputInjector, error) {
|
||||||
|
initDarwinInput()
|
||||||
|
if !darwinInputReady {
|
||||||
|
return nil, fmt.Errorf("CoreGraphics not available for input injection")
|
||||||
|
}
|
||||||
|
checkMacPermissions()
|
||||||
|
|
||||||
|
m := &MacInputInjector{}
|
||||||
|
if path, err := exec.LookPath("pbcopy"); err == nil {
|
||||||
|
m.pbcopyPath = path
|
||||||
|
}
|
||||||
|
if path, err := exec.LookPath("pbpaste"); err == nil {
|
||||||
|
m.pbpastePath = path
|
||||||
|
}
|
||||||
|
if m.pbcopyPath == "" || m.pbpastePath == "" {
|
||||||
|
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
holdPreventIdleSleep()
|
||||||
|
|
||||||
|
log.Info("macOS input injector ready")
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkMacPermissions warns and opens the Privacy pane if Accessibility is
|
||||||
|
// missing. Uses AXIsProcessTrusted which returns immediately; the previous
|
||||||
|
// osascript probe blocked for 120s (AppleEvent timeout) when access was
|
||||||
|
// denied, which delayed VNC server startup past client deadlines.
|
||||||
|
func checkMacPermissions() {
|
||||||
|
if axIsProcessTrusted != nil && !axIsProcessTrusted() {
|
||||||
|
openPrivacyPane("Privacy_Accessibility")
|
||||||
|
log.Warn("Accessibility permission not granted. Input injection will not work. " +
|
||||||
|
"Opened System Settings > Privacy & Security > Accessibility; enable netbird.")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Screen Recording permission is required for screen capture. " +
|
||||||
|
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// openPrivacyPane opens the given Privacy pane in System Settings so the user
|
||||||
|
// can toggle the permission without navigating manually.
|
||||||
|
func openPrivacyPane(pane string) {
|
||||||
|
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
|
||||||
|
if err := exec.Command("open", url).Start(); err != nil {
|
||||||
|
log.Debugf("open privacy pane %s: %v", pane, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectKey simulates a key press or release.
|
||||||
|
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
|
||||||
|
wakeDisplay()
|
||||||
|
src := ensureEventSource()
|
||||||
|
if src == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keycode := keysymToMacKeycode(keysym)
|
||||||
|
if keycode == 0xFFFF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||||
|
if event == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cgEventPost(kCGHIDEventTap, event)
|
||||||
|
cfRelease(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectPointer simulates mouse movement and button events.
|
||||||
|
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||||
|
wakeDisplay()
|
||||||
|
if serverW == 0 || serverH == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
src := ensureEventSource()
|
||||||
|
if src == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Framebuffer is in physical pixels (Retina). CGEventCreateMouseEvent
|
||||||
|
// expects logical points, so scale down by the display's pixel/point ratio.
|
||||||
|
x := float64(px)
|
||||||
|
y := float64(py)
|
||||||
|
if cgDisplayPixelsWide != nil && cgMainDisplayID != nil {
|
||||||
|
displayID := cgMainDisplayID()
|
||||||
|
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||||
|
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||||
|
if logicalW > 0 && logicalH > 0 {
|
||||||
|
x = float64(px) * float64(logicalW) / float64(serverW)
|
||||||
|
y = float64(py) * float64(logicalH) / float64(serverH)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
leftDown := buttonMask&0x01 != 0
|
||||||
|
rightDown := buttonMask&0x04 != 0
|
||||||
|
middleDown := buttonMask&0x02 != 0
|
||||||
|
scrollUp := buttonMask&0x08 != 0
|
||||||
|
scrollDown := buttonMask&0x10 != 0
|
||||||
|
|
||||||
|
wasLeft := m.lastButtons&0x01 != 0
|
||||||
|
wasRight := m.lastButtons&0x04 != 0
|
||||||
|
wasMiddle := m.lastButtons&0x02 != 0
|
||||||
|
|
||||||
|
if leftDown {
|
||||||
|
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
|
||||||
|
} else if rightDown {
|
||||||
|
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
|
||||||
|
} else {
|
||||||
|
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
|
||||||
|
}
|
||||||
|
|
||||||
|
if leftDown && !wasLeft {
|
||||||
|
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
|
||||||
|
} else if !leftDown && wasLeft {
|
||||||
|
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
|
||||||
|
}
|
||||||
|
if rightDown && !wasRight {
|
||||||
|
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
|
||||||
|
} else if !rightDown && wasRight {
|
||||||
|
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
|
||||||
|
}
|
||||||
|
if middleDown && !wasMiddle {
|
||||||
|
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
|
||||||
|
} else if !middleDown && wasMiddle {
|
||||||
|
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scrollUp {
|
||||||
|
m.postScroll(src, 3)
|
||||||
|
}
|
||||||
|
if scrollDown {
|
||||||
|
m.postScroll(src, -3)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lastButtons = buttonMask
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
|
||||||
|
if cgEventCreateMouseEvent == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||||
|
if event == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cgEventPost(kCGHIDEventTap, event)
|
||||||
|
cfRelease(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
|
||||||
|
if cgEventCreateScrollWheelEventAddr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
|
||||||
|
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
|
||||||
|
// Variadic C function: pass args as uintptr via SyscallN.
|
||||||
|
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
|
||||||
|
src, 0, 1, uintptr(uint32(deltaY)))
|
||||||
|
if r1 == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cgEventPost(kCGHIDEventTap, r1)
|
||||||
|
cfRelease(r1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClipboard sets the macOS clipboard using pbcopy.
|
||||||
|
func (m *MacInputInjector) SetClipboard(text string) {
|
||||||
|
if m.pbcopyPath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cmd := exec.Command(m.pbcopyPath)
|
||||||
|
cmd.Stdin = strings.NewReader(text)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Tracef("set clipboard via pbcopy: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClipboard reads the macOS clipboard using pbpaste.
|
||||||
|
func (m *MacInputInjector) GetClipboard() string {
|
||||||
|
if m.pbpastePath == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
out, err := exec.Command(m.pbpastePath).Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("get clipboard via pbpaste: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases the idle-sleep assertion held for the injector's lifetime.
|
||||||
|
func (m *MacInputInjector) Close() {
|
||||||
|
releasePreventIdleSleep()
|
||||||
|
}
|
||||||
|
|
||||||
|
func keysymToMacKeycode(keysym uint32) uint16 {
|
||||||
|
if keysym >= 0x61 && keysym <= 0x7a {
|
||||||
|
return asciiToMacKey[keysym-0x61]
|
||||||
|
}
|
||||||
|
if keysym >= 0x41 && keysym <= 0x5a {
|
||||||
|
return asciiToMacKey[keysym-0x41]
|
||||||
|
}
|
||||||
|
if keysym >= 0x30 && keysym <= 0x39 {
|
||||||
|
return digitToMacKey[keysym-0x30]
|
||||||
|
}
|
||||||
|
if code, ok := specialKeyMap[keysym]; ok {
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
return 0xFFFF
|
||||||
|
}
|
||||||
|
|
||||||
|
var asciiToMacKey = [26]uint16{
|
||||||
|
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
|
||||||
|
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
|
||||||
|
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
|
||||||
|
0x10, 0x06,
|
||||||
|
}
|
||||||
|
|
||||||
|
var digitToMacKey = [10]uint16{
|
||||||
|
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
|
||||||
|
}
|
||||||
|
|
||||||
|
var specialKeyMap = map[uint32]uint16{
|
||||||
|
// Whitespace and editing
|
||||||
|
0x0020: 0x31, // space
|
||||||
|
0xff08: 0x33, // BackSpace
|
||||||
|
0xff09: 0x30, // Tab
|
||||||
|
0xff0d: 0x24, // Return
|
||||||
|
0xff1b: 0x35, // Escape
|
||||||
|
0xffff: 0x75, // Delete (forward)
|
||||||
|
|
||||||
|
// Navigation
|
||||||
|
0xff50: 0x73, // Home
|
||||||
|
0xff51: 0x7B, // Left
|
||||||
|
0xff52: 0x7E, // Up
|
||||||
|
0xff53: 0x7C, // Right
|
||||||
|
0xff54: 0x7D, // Down
|
||||||
|
0xff55: 0x74, // Page_Up
|
||||||
|
0xff56: 0x79, // Page_Down
|
||||||
|
0xff57: 0x77, // End
|
||||||
|
0xff63: 0x72, // Insert (Help on Mac)
|
||||||
|
|
||||||
|
// Modifiers
|
||||||
|
0xffe1: 0x38, // Shift_L
|
||||||
|
0xffe2: 0x3C, // Shift_R
|
||||||
|
0xffe3: 0x3B, // Control_L
|
||||||
|
0xffe4: 0x3E, // Control_R
|
||||||
|
0xffe5: 0x39, // Caps_Lock
|
||||||
|
0xffe9: 0x3A, // Alt_L (Option)
|
||||||
|
0xffea: 0x3D, // Alt_R (Option)
|
||||||
|
0xffe7: 0x37, // Meta_L (Command)
|
||||||
|
0xffe8: 0x36, // Meta_R (Command)
|
||||||
|
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
|
||||||
|
0xffec: 0x36, // Super_R (Command)
|
||||||
|
|
||||||
|
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
|
||||||
|
0xff7e: 0x3A, // Mode_switch -> Option
|
||||||
|
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
|
||||||
|
|
||||||
|
// Function keys
|
||||||
|
0xffbe: 0x7A, // F1
|
||||||
|
0xffbf: 0x78, // F2
|
||||||
|
0xffc0: 0x63, // F3
|
||||||
|
0xffc1: 0x76, // F4
|
||||||
|
0xffc2: 0x60, // F5
|
||||||
|
0xffc3: 0x61, // F6
|
||||||
|
0xffc4: 0x62, // F7
|
||||||
|
0xffc5: 0x64, // F8
|
||||||
|
0xffc6: 0x65, // F9
|
||||||
|
0xffc7: 0x6D, // F10
|
||||||
|
0xffc8: 0x67, // F11
|
||||||
|
0xffc9: 0x6F, // F12
|
||||||
|
0xffca: 0x69, // F13
|
||||||
|
0xffcb: 0x6B, // F14
|
||||||
|
0xffcc: 0x71, // F15
|
||||||
|
0xffcd: 0x6A, // F16
|
||||||
|
0xffce: 0x40, // F17
|
||||||
|
0xffcf: 0x4F, // F18
|
||||||
|
0xffd0: 0x50, // F19
|
||||||
|
0xffd1: 0x5A, // F20
|
||||||
|
|
||||||
|
// Punctuation (US keyboard layout, keysym = ASCII code)
|
||||||
|
0x002d: 0x1B, // minus -
|
||||||
|
0x003d: 0x18, // equal =
|
||||||
|
0x005b: 0x21, // bracketleft [
|
||||||
|
0x005d: 0x1E, // bracketright ]
|
||||||
|
0x005c: 0x2A, // backslash
|
||||||
|
0x003b: 0x29, // semicolon ;
|
||||||
|
0x0027: 0x27, // apostrophe '
|
||||||
|
0x0060: 0x32, // grave `
|
||||||
|
0x002c: 0x2B, // comma ,
|
||||||
|
0x002e: 0x2F, // period .
|
||||||
|
0x002f: 0x2C, // slash /
|
||||||
|
|
||||||
|
// Shifted punctuation (noVNC sends these as separate keysyms)
|
||||||
|
0x005f: 0x1B, // underscore _ (shift+minus)
|
||||||
|
0x002b: 0x18, // plus + (shift+equal)
|
||||||
|
0x007b: 0x21, // braceleft { (shift+[)
|
||||||
|
0x007d: 0x1E, // braceright } (shift+])
|
||||||
|
0x007c: 0x2A, // bar | (shift+\)
|
||||||
|
0x003a: 0x29, // colon : (shift+;)
|
||||||
|
0x0022: 0x27, // quotedbl " (shift+')
|
||||||
|
0x007e: 0x32, // tilde ~ (shift+`)
|
||||||
|
0x003c: 0x2B, // less < (shift+,)
|
||||||
|
0x003e: 0x2F, // greater > (shift+.)
|
||||||
|
0x003f: 0x2C, // question ? (shift+/)
|
||||||
|
0x0021: 0x12, // exclam ! (shift+1)
|
||||||
|
0x0040: 0x13, // at @ (shift+2)
|
||||||
|
0x0023: 0x14, // numbersign # (shift+3)
|
||||||
|
0x0024: 0x15, // dollar $ (shift+4)
|
||||||
|
0x0025: 0x17, // percent % (shift+5)
|
||||||
|
0x005e: 0x16, // asciicircum ^ (shift+6)
|
||||||
|
0x0026: 0x1A, // ampersand & (shift+7)
|
||||||
|
0x002a: 0x1C, // asterisk * (shift+8)
|
||||||
|
0x0028: 0x19, // parenleft ( (shift+9)
|
||||||
|
0x0029: 0x1D, // parenright ) (shift+0)
|
||||||
|
|
||||||
|
// Numpad
|
||||||
|
0xffb0: 0x52, // KP_0
|
||||||
|
0xffb1: 0x53, // KP_1
|
||||||
|
0xffb2: 0x54, // KP_2
|
||||||
|
0xffb3: 0x55, // KP_3
|
||||||
|
0xffb4: 0x56, // KP_4
|
||||||
|
0xffb5: 0x57, // KP_5
|
||||||
|
0xffb6: 0x58, // KP_6
|
||||||
|
0xffb7: 0x59, // KP_7
|
||||||
|
0xffb8: 0x5B, // KP_8
|
||||||
|
0xffb9: 0x5C, // KP_9
|
||||||
|
0xffae: 0x41, // KP_Decimal
|
||||||
|
0xffaa: 0x43, // KP_Multiply
|
||||||
|
0xffab: 0x45, // KP_Add
|
||||||
|
0xffad: 0x4E, // KP_Subtract
|
||||||
|
0xffaf: 0x4B, // KP_Divide
|
||||||
|
0xff8d: 0x4C, // KP_Enter
|
||||||
|
0xffbd: 0x51, // KP_Equal
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ InputInjector = (*MacInputInjector)(nil)
|
||||||
398
client/vnc/server/input_windows.go
Normal file
398
client/vnc/server/input_windows.go
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
procOpenEventW = kernel32.NewProc("OpenEventW")
|
||||||
|
procSendInput = user32.NewProc("SendInput")
|
||||||
|
procVkKeyScanA = user32.NewProc("VkKeyScanA")
|
||||||
|
)
|
||||||
|
|
||||||
|
const eventModifyState = 0x0002
|
||||||
|
|
||||||
|
const (
|
||||||
|
inputMouse = 0
|
||||||
|
inputKeyboard = 1
|
||||||
|
|
||||||
|
mouseeventfMove = 0x0001
|
||||||
|
mouseeventfLeftDown = 0x0002
|
||||||
|
mouseeventfLeftUp = 0x0004
|
||||||
|
mouseeventfRightDown = 0x0008
|
||||||
|
mouseeventfRightUp = 0x0010
|
||||||
|
mouseeventfMiddleDown = 0x0020
|
||||||
|
mouseeventfMiddleUp = 0x0040
|
||||||
|
mouseeventfWheel = 0x0800
|
||||||
|
mouseeventfAbsolute = 0x8000
|
||||||
|
|
||||||
|
wheelDelta = 120
|
||||||
|
|
||||||
|
keyeventfKeyUp = 0x0002
|
||||||
|
keyeventfScanCode = 0x0008
|
||||||
|
)
|
||||||
|
|
||||||
|
type mouseInput struct {
|
||||||
|
Dx int32
|
||||||
|
Dy int32
|
||||||
|
MouseData uint32
|
||||||
|
DwFlags uint32
|
||||||
|
Time uint32
|
||||||
|
DwExtraInfo uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type keybdInput struct {
|
||||||
|
WVk uint16
|
||||||
|
WScan uint16
|
||||||
|
DwFlags uint32
|
||||||
|
Time uint32
|
||||||
|
DwExtraInfo uintptr
|
||||||
|
_ [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type inputUnion [32]byte
|
||||||
|
|
||||||
|
type winInput struct {
|
||||||
|
Type uint32
|
||||||
|
_ [4]byte
|
||||||
|
Data inputUnion
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
|
||||||
|
mi := mouseInput{
|
||||||
|
Dx: dx,
|
||||||
|
Dy: dy,
|
||||||
|
MouseData: mouseData,
|
||||||
|
DwFlags: flags,
|
||||||
|
}
|
||||||
|
inp := winInput{Type: inputMouse}
|
||||||
|
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
|
||||||
|
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
|
||||||
|
ki := keybdInput{
|
||||||
|
WVk: vk,
|
||||||
|
WScan: scanCode,
|
||||||
|
DwFlags: flags,
|
||||||
|
}
|
||||||
|
inp := winInput{Type: inputKeyboard}
|
||||||
|
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
|
||||||
|
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const sasEventName = `Global\NetBirdVNC_SAS`
|
||||||
|
|
||||||
|
type inputCmd struct {
|
||||||
|
isKey bool
|
||||||
|
keysym uint32
|
||||||
|
down bool
|
||||||
|
buttonMask uint8
|
||||||
|
x, y int
|
||||||
|
serverW int
|
||||||
|
serverH int
|
||||||
|
}
|
||||||
|
|
||||||
|
// WindowsInputInjector delivers input events from a dedicated OS thread that
|
||||||
|
// calls switchToInputDesktop before each injection. SendInput targets the
|
||||||
|
// calling thread's desktop, so the injection thread must be on the same
|
||||||
|
// desktop the user sees.
|
||||||
|
type WindowsInputInjector struct {
|
||||||
|
ch chan inputCmd
|
||||||
|
prevButtonMask uint8
|
||||||
|
ctrlDown bool
|
||||||
|
altDown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWindowsInputInjector creates a desktop-aware input injector.
|
||||||
|
func NewWindowsInputInjector() *WindowsInputInjector {
|
||||||
|
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
|
||||||
|
go w.loop()
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WindowsInputInjector) loop() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
for cmd := range w.ch {
|
||||||
|
// Switch to the current input desktop so SendInput reaches the right target.
|
||||||
|
switchToInputDesktop()
|
||||||
|
|
||||||
|
if cmd.isKey {
|
||||||
|
w.doInjectKey(cmd.keysym, cmd.down)
|
||||||
|
} else {
|
||||||
|
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectKey queues a key event for injection on the input desktop thread.
|
||||||
|
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
|
||||||
|
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectPointer queues a pointer event for injection on the input desktop thread.
|
||||||
|
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||||
|
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
|
||||||
|
switch keysym {
|
||||||
|
case 0xffe3, 0xffe4:
|
||||||
|
w.ctrlDown = down
|
||||||
|
case 0xffe9, 0xffea:
|
||||||
|
w.altDown = down
|
||||||
|
}
|
||||||
|
|
||||||
|
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||||
|
signalSAS()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vk, _, extended := keysym2VK(keysym)
|
||||||
|
if vk == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var flags uint32
|
||||||
|
if !down {
|
||||||
|
flags |= keyeventfKeyUp
|
||||||
|
}
|
||||||
|
if extended {
|
||||||
|
flags |= keyeventfScanCode
|
||||||
|
}
|
||||||
|
sendKeyInput(vk, 0, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// signalSAS signals the SAS named event. A listener in Session 0
|
||||||
|
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
|
||||||
|
func signalSAS() {
|
||||||
|
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("SAS UTF16: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h, _, lerr := procOpenEventW.Call(
|
||||||
|
uintptr(eventModifyState),
|
||||||
|
0,
|
||||||
|
uintptr(unsafe.Pointer(namePtr)),
|
||||||
|
)
|
||||||
|
if h == 0 {
|
||||||
|
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ev := windows.Handle(h)
|
||||||
|
defer windows.CloseHandle(ev)
|
||||||
|
if err := windows.SetEvent(ev); err != nil {
|
||||||
|
log.Warnf("SetEvent SAS: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Info("SAS event signaled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||||
|
if serverW == 0 || serverH == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
absX := int32(x * 65535 / serverW)
|
||||||
|
absY := int32(y * 65535 / serverH)
|
||||||
|
|
||||||
|
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
|
||||||
|
|
||||||
|
changed := buttonMask ^ w.prevButtonMask
|
||||||
|
w.prevButtonMask = buttonMask
|
||||||
|
|
||||||
|
type btnMap struct {
|
||||||
|
bit uint8
|
||||||
|
down uint32
|
||||||
|
up uint32
|
||||||
|
}
|
||||||
|
buttons := [...]btnMap{
|
||||||
|
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
|
||||||
|
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
|
||||||
|
{0x04, mouseeventfRightDown, mouseeventfRightUp},
|
||||||
|
}
|
||||||
|
for _, b := range buttons {
|
||||||
|
if changed&b.bit == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var flags uint32
|
||||||
|
if buttonMask&b.bit != 0 {
|
||||||
|
flags = b.down
|
||||||
|
} else {
|
||||||
|
flags = b.up
|
||||||
|
}
|
||||||
|
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
negWheelDelta := ^uint32(wheelDelta - 1)
|
||||||
|
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
|
||||||
|
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
|
||||||
|
}
|
||||||
|
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
|
||||||
|
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// keysym2VK converts an X11 keysym to a Windows virtual key code.
|
||||||
|
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
|
||||||
|
if keysym >= 0x20 && keysym <= 0x7e {
|
||||||
|
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
|
||||||
|
vk = uint16(r & 0xff)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if keysym >= 0xffbe && keysym <= 0xffc9 {
|
||||||
|
vk = uint16(0x70 + keysym - 0xffbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keysym {
|
||||||
|
case 0xff08:
|
||||||
|
vk = 0x08 // Backspace
|
||||||
|
case 0xff09:
|
||||||
|
vk = 0x09 // Tab
|
||||||
|
case 0xff0d:
|
||||||
|
vk = 0x0d // Return
|
||||||
|
case 0xff1b:
|
||||||
|
vk = 0x1b // Escape
|
||||||
|
case 0xff63:
|
||||||
|
vk, extended = 0x2d, true // Insert
|
||||||
|
case 0xff9f, 0xffff:
|
||||||
|
vk, extended = 0x2e, true // Delete
|
||||||
|
case 0xff50:
|
||||||
|
vk, extended = 0x24, true // Home
|
||||||
|
case 0xff57:
|
||||||
|
vk, extended = 0x23, true // End
|
||||||
|
case 0xff55:
|
||||||
|
vk, extended = 0x21, true // PageUp
|
||||||
|
case 0xff56:
|
||||||
|
vk, extended = 0x22, true // PageDown
|
||||||
|
case 0xff51:
|
||||||
|
vk, extended = 0x25, true // Left
|
||||||
|
case 0xff52:
|
||||||
|
vk, extended = 0x26, true // Up
|
||||||
|
case 0xff53:
|
||||||
|
vk, extended = 0x27, true // Right
|
||||||
|
case 0xff54:
|
||||||
|
vk, extended = 0x28, true // Down
|
||||||
|
case 0xffe1, 0xffe2:
|
||||||
|
vk = 0x10 // Shift
|
||||||
|
case 0xffe3, 0xffe4:
|
||||||
|
vk = 0x11 // Control
|
||||||
|
case 0xffe9, 0xffea:
|
||||||
|
vk = 0x12 // Alt
|
||||||
|
case 0xffe5:
|
||||||
|
vk = 0x14 // CapsLock
|
||||||
|
case 0xffe7, 0xffeb:
|
||||||
|
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
|
||||||
|
case 0xffe8, 0xffec:
|
||||||
|
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
|
||||||
|
case 0xff61:
|
||||||
|
vk = 0x2c // PrintScreen
|
||||||
|
case 0xff13:
|
||||||
|
vk = 0x13 // Pause
|
||||||
|
case 0xff14:
|
||||||
|
vk = 0x91 // ScrollLock
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
procOpenClipboard = user32.NewProc("OpenClipboard")
|
||||||
|
procCloseClipboard = user32.NewProc("CloseClipboard")
|
||||||
|
procEmptyClipboard = user32.NewProc("EmptyClipboard")
|
||||||
|
procSetClipboardData = user32.NewProc("SetClipboardData")
|
||||||
|
procGetClipboardData = user32.NewProc("GetClipboardData")
|
||||||
|
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
|
||||||
|
|
||||||
|
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
|
||||||
|
procGlobalLock = kernel32.NewProc("GlobalLock")
|
||||||
|
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cfUnicodeText = 13
|
||||||
|
gmemMoveable = 0x0002
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
|
||||||
|
func (w *WindowsInputInjector) SetClipboard(text string) {
|
||||||
|
utf16, err := windows.UTF16FromString(text)
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
size := uintptr(len(utf16) * 2)
|
||||||
|
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
|
||||||
|
if hMem == 0 {
|
||||||
|
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr, _, _ := procGlobalLock.Call(hMem)
|
||||||
|
if ptr == 0 {
|
||||||
|
log.Tracef("GlobalLock for clipboard: lock returned nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
|
||||||
|
procGlobalUnlock.Call(hMem)
|
||||||
|
|
||||||
|
r, _, lerr := procOpenClipboard.Call(0)
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("OpenClipboard: %v", lerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer procCloseClipboard.Call()
|
||||||
|
|
||||||
|
procEmptyClipboard.Call()
|
||||||
|
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("SetClipboardData: %v", lerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClipboard reads the Windows clipboard as UTF-8 text.
|
||||||
|
func (w *WindowsInputInjector) GetClipboard() string {
|
||||||
|
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
|
||||||
|
if r == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _, lerr := procOpenClipboard.Call(0)
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("OpenClipboard for read: %v", lerr)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer procCloseClipboard.Call()
|
||||||
|
|
||||||
|
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
|
||||||
|
if hData == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr, _, _ := procGlobalLock.Call(hData)
|
||||||
|
if ptr == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer procGlobalUnlock.Call(hData)
|
||||||
|
|
||||||
|
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ InputInjector = (*WindowsInputInjector)(nil)
|
||||||
|
|
||||||
|
var _ ScreenCapturer = (*DesktopCapturer)(nil)
|
||||||
242
client/vnc/server/input_x11.go
Normal file
242
client/vnc/server/input_x11.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/jezek/xgb"
|
||||||
|
"github.com/jezek/xgb/xproto"
|
||||||
|
"github.com/jezek/xgb/xtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// X11InputInjector injects keyboard and mouse events via the XTest extension.
|
||||||
|
type X11InputInjector struct {
|
||||||
|
conn *xgb.Conn
|
||||||
|
root xproto.Window
|
||||||
|
screen *xproto.ScreenInfo
|
||||||
|
display string
|
||||||
|
keysymMap map[uint32]byte
|
||||||
|
lastButtons uint8
|
||||||
|
clipboardTool string
|
||||||
|
clipboardToolName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewX11InputInjector connects to the X11 display and initializes XTest.
|
||||||
|
func NewX11InputInjector(display string) (*X11InputInjector, error) {
|
||||||
|
detectX11Display()
|
||||||
|
|
||||||
|
if display == "" {
|
||||||
|
display = os.Getenv("DISPLAY")
|
||||||
|
}
|
||||||
|
if display == "" {
|
||||||
|
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := xgb.NewConnDisplay(display)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := xtest.Init(conn); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("init XTest extension: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setup := xproto.Setup(conn)
|
||||||
|
if len(setup.Roots) == 0 {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("no X11 screens")
|
||||||
|
}
|
||||||
|
screen := setup.Roots[0]
|
||||||
|
|
||||||
|
inj := &X11InputInjector{
|
||||||
|
conn: conn,
|
||||||
|
root: screen.Root,
|
||||||
|
screen: &screen,
|
||||||
|
display: display,
|
||||||
|
}
|
||||||
|
inj.cacheKeyboardMapping()
|
||||||
|
inj.resolveClipboardTool()
|
||||||
|
|
||||||
|
log.Infof("X11 input injector ready (display=%s)", display)
|
||||||
|
return inj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||||
|
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
|
||||||
|
keycode := x.keysymToKeycode(keysym)
|
||||||
|
if keycode == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventType byte
|
||||||
|
if down {
|
||||||
|
eventType = xproto.KeyPress
|
||||||
|
} else {
|
||||||
|
eventType = xproto.KeyRelease
|
||||||
|
}
|
||||||
|
|
||||||
|
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectPointer simulates mouse movement and button events.
|
||||||
|
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||||
|
if serverW == 0 || serverH == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale to actual screen coordinates.
|
||||||
|
screenW := int(x.screen.WidthInPixels)
|
||||||
|
screenH := int(x.screen.HeightInPixels)
|
||||||
|
absX := px * screenW / serverW
|
||||||
|
absY := py * screenH / serverH
|
||||||
|
|
||||||
|
// Move pointer.
|
||||||
|
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
|
||||||
|
|
||||||
|
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
|
||||||
|
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
|
||||||
|
// 4=scrollUp, 5=scrollDown.
|
||||||
|
type btnMap struct {
|
||||||
|
rfbBit uint8
|
||||||
|
x11Btn byte
|
||||||
|
}
|
||||||
|
buttons := [...]btnMap{
|
||||||
|
{0x01, 1}, // left
|
||||||
|
{0x02, 2}, // middle
|
||||||
|
{0x04, 3}, // right
|
||||||
|
{0x08, 4}, // scroll up
|
||||||
|
{0x10, 5}, // scroll down
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range buttons {
|
||||||
|
pressed := buttonMask&b.rfbBit != 0
|
||||||
|
wasPressed := x.lastButtons&b.rfbBit != 0
|
||||||
|
if b.x11Btn >= 4 {
|
||||||
|
// Scroll: send press+release on each scroll event.
|
||||||
|
if pressed {
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if pressed && !wasPressed {
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
} else if !pressed && wasPressed {
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x.lastButtons = buttonMask
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
|
||||||
|
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
|
||||||
|
func (x *X11InputInjector) cacheKeyboardMapping() {
|
||||||
|
setup := xproto.Setup(x.conn)
|
||||||
|
minKeycode := setup.MinKeycode
|
||||||
|
maxKeycode := setup.MaxKeycode
|
||||||
|
|
||||||
|
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
|
||||||
|
byte(maxKeycode-minKeycode+1)).Reply()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("cache keyboard mapping: %v", err)
|
||||||
|
x.keysymMap = make(map[uint32]byte)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
|
||||||
|
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
|
||||||
|
for i := int(minKeycode); i <= int(maxKeycode); i++ {
|
||||||
|
offset := (i - int(minKeycode)) * keysymsPerKeycode
|
||||||
|
for j := 0; j < keysymsPerKeycode; j++ {
|
||||||
|
ks := uint32(reply.Keysyms[offset+j])
|
||||||
|
if ks != 0 {
|
||||||
|
if _, exists := m[ks]; !exists {
|
||||||
|
m[ks] = byte(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x.keysymMap = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
|
||||||
|
// Returns 0 if the keysym is not mapped.
|
||||||
|
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
|
||||||
|
return x.keysymMap[keysym]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClipboard sets the X11 clipboard using xclip or xsel.
|
||||||
|
func (x *X11InputInjector) SetClipboard(text string) {
|
||||||
|
if x.clipboardTool == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if x.clipboardToolName == "xclip" {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
|
||||||
|
}
|
||||||
|
cmd.Env = x.clipboardEnv()
|
||||||
|
cmd.Stdin = strings.NewReader(text)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *X11InputInjector) resolveClipboardTool() {
|
||||||
|
for _, name := range []string{"xclip", "xsel"} {
|
||||||
|
path, err := exec.LookPath(name)
|
||||||
|
if err == nil {
|
||||||
|
x.clipboardTool = path
|
||||||
|
x.clipboardToolName = name
|
||||||
|
log.Debugf("clipboard tool resolved to %s", path)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClipboard reads the X11 clipboard using xclip or xsel.
|
||||||
|
func (x *X11InputInjector) GetClipboard() string {
|
||||||
|
if x.clipboardTool == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if x.clipboardToolName == "xclip" {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
|
||||||
|
}
|
||||||
|
cmd.Env = x.clipboardEnv()
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *X11InputInjector) clipboardEnv() []string {
|
||||||
|
env := []string{"DISPLAY=" + x.display}
|
||||||
|
if auth := os.Getenv("XAUTHORITY"); auth != "" {
|
||||||
|
env = append(env, "XAUTHORITY="+auth)
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases X11 resources.
|
||||||
|
func (x *X11InputInjector) Close() {
|
||||||
|
x.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ InputInjector = (*X11InputInjector)(nil)
|
||||||
|
var _ ScreenCapturer = (*X11Poller)(nil)
|
||||||
175
client/vnc/server/recorder.go
Normal file
175
client/vnc/server/recorder.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/png"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recording file format:
|
||||||
|
//
|
||||||
|
// Header: magic(6) + width(2) + height(2) + startTime(8) + metaLen(4) + metaJSON
|
||||||
|
// Frames: offsetMs(4) + pngLen(4) + PNG image data
|
||||||
|
//
|
||||||
|
// Each frame is a PNG-encoded screenshot. Only changed frames are stored.
|
||||||
|
const recMagic = "NBVNC\x01"
|
||||||
|
|
||||||
|
// RecordingMeta holds metadata written to the recording file header.
|
||||||
|
type RecordingMeta struct {
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
RemoteAddr string `json:"remote_addr"`
|
||||||
|
JWTUser string `json:"jwt_user,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Encrypted bool `json:"encrypted,omitempty"`
|
||||||
|
EphemeralKey string `json:"ephemeral_key,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// vncRecorder writes VNC session frames to a recording file.
|
||||||
|
type vncRecorder struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
file *os.File
|
||||||
|
startTime time.Time
|
||||||
|
closed bool
|
||||||
|
log *log.Entry
|
||||||
|
prevFrame *image.RGBA
|
||||||
|
pngEnc *png.Encoder
|
||||||
|
pngBuf bytes.Buffer
|
||||||
|
crypto *recCrypto
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVNCRecorder(dir string, width, height int, meta *RecordingMeta, encryptionKey string, logger *log.Entry) (*vncRecorder, error) {
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return nil, fmt.Errorf("create recording dir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
filename := fmt.Sprintf("%s_vnc.rec", now.Format("20060102-150405"))
|
||||||
|
filePath := filepath.Join(dir, filename)
|
||||||
|
|
||||||
|
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create recording file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var crypto *recCrypto
|
||||||
|
if encryptionKey != "" {
|
||||||
|
var cryptoErr error
|
||||||
|
crypto, cryptoErr = newRecCrypto(encryptionKey)
|
||||||
|
if cryptoErr != nil {
|
||||||
|
f.Close()
|
||||||
|
os.Remove(filePath)
|
||||||
|
return nil, fmt.Errorf("init encryption: %w", cryptoErr)
|
||||||
|
}
|
||||||
|
meta.Encrypted = true
|
||||||
|
meta.EphemeralKey = base64.StdEncoding.EncodeToString(crypto.ephemeralPub)
|
||||||
|
}
|
||||||
|
|
||||||
|
metaJSON, err := json.Marshal(meta)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
os.Remove(filePath)
|
||||||
|
return nil, fmt.Errorf("marshal meta: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var hdr [6 + 2 + 2 + 8 + 4]byte
|
||||||
|
copy(hdr[:6], recMagic)
|
||||||
|
binary.BigEndian.PutUint16(hdr[6:8], uint16(width))
|
||||||
|
binary.BigEndian.PutUint16(hdr[8:10], uint16(height))
|
||||||
|
binary.BigEndian.PutUint64(hdr[10:18], uint64(now.UnixMilli()))
|
||||||
|
binary.BigEndian.PutUint32(hdr[18:22], uint32(len(metaJSON)))
|
||||||
|
|
||||||
|
if _, err := f.Write(hdr[:]); err != nil {
|
||||||
|
f.Close()
|
||||||
|
os.Remove(filePath)
|
||||||
|
return nil, fmt.Errorf("write header: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := f.Write(metaJSON); err != nil {
|
||||||
|
f.Close()
|
||||||
|
os.Remove(filePath)
|
||||||
|
return nil, fmt.Errorf("write meta: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &vncRecorder{
|
||||||
|
file: f,
|
||||||
|
startTime: now,
|
||||||
|
log: logger.WithField("recording", filepath.Base(filePath)),
|
||||||
|
pngEnc: &png.Encoder{CompressionLevel: png.BestSpeed},
|
||||||
|
crypto: crypto,
|
||||||
|
}
|
||||||
|
if crypto != nil {
|
||||||
|
r.log.Infof("VNC recording started (encrypted): %s", filePath)
|
||||||
|
} else {
|
||||||
|
r.log.Infof("VNC recording started: %s", filePath)
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeFrame records a screen frame. Only writes if the frame differs from the previous one.
|
||||||
|
func (r *vncRecorder) writeFrame(img *image.RGBA) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.prevFrame != nil && bytes.Equal(r.prevFrame.Pix, img.Pix) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offsetMs := uint32(time.Since(r.startTime).Milliseconds())
|
||||||
|
|
||||||
|
r.pngBuf.Reset()
|
||||||
|
if err := r.pngEnc.Encode(&r.pngBuf, img); err != nil {
|
||||||
|
r.log.Debugf("encode PNG frame: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
frameData := r.pngBuf.Bytes()
|
||||||
|
if r.crypto != nil {
|
||||||
|
frameData = r.crypto.encrypt(frameData)
|
||||||
|
}
|
||||||
|
|
||||||
|
var frameHdr [8]byte
|
||||||
|
binary.BigEndian.PutUint32(frameHdr[0:4], offsetMs)
|
||||||
|
binary.BigEndian.PutUint32(frameHdr[4:8], uint32(len(frameData)))
|
||||||
|
|
||||||
|
if _, err := r.file.Write(frameHdr[:]); err != nil {
|
||||||
|
r.log.Debugf("write frame header: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := r.file.Write(frameData); err != nil {
|
||||||
|
r.log.Debugf("write frame data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.prevFrame == nil {
|
||||||
|
r.prevFrame = image.NewRGBA(img.Rect)
|
||||||
|
}
|
||||||
|
copy(r.prevFrame.Pix, img.Pix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *vncRecorder) close() {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if r.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.closed = true
|
||||||
|
|
||||||
|
duration := time.Since(r.startTime)
|
||||||
|
r.log.Infof("VNC recording stopped after %v", duration.Round(time.Millisecond))
|
||||||
|
r.file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
202
client/vnc/server/recorder_test.go
Normal file
202
client/vnc/server/recorder_test.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeTestImage(w, h int, c color.RGBA) *image.RGBA {
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||||
|
for i := 0; i < len(img.Pix); i += 4 {
|
||||||
|
img.Pix[i] = c.R
|
||||||
|
img.Pix[i+1] = c.G
|
||||||
|
img.Pix[i+2] = c.B
|
||||||
|
img.Pix[i+3] = c.A
|
||||||
|
}
|
||||||
|
return img
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecorderWriteAndReadHeader(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
logger := log.WithField("test", t.Name())
|
||||||
|
|
||||||
|
meta := &RecordingMeta{
|
||||||
|
User: "alice",
|
||||||
|
RemoteAddr: "100.0.1.5:12345",
|
||||||
|
JWTUser: "google|123",
|
||||||
|
Mode: "session",
|
||||||
|
}
|
||||||
|
|
||||||
|
rec, err := newVNCRecorder(dir, 800, 600, meta, "", logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Write some frames
|
||||||
|
red := makeTestImage(800, 600, color.RGBA{255, 0, 0, 255})
|
||||||
|
blue := makeTestImage(800, 600, color.RGBA{0, 0, 255, 255})
|
||||||
|
|
||||||
|
rec.writeFrame(red)
|
||||||
|
rec.writeFrame(red) // duplicate, should be skipped
|
||||||
|
rec.writeFrame(blue)
|
||||||
|
rec.close()
|
||||||
|
|
||||||
|
// Read back the header
|
||||||
|
files, err := os.ReadDir(dir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, files, 1)
|
||||||
|
|
||||||
|
filePath := filepath.Join(dir, files[0].Name())
|
||||||
|
header, err := ReadRecordingHeader(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 800, header.Width)
|
||||||
|
assert.Equal(t, 600, header.Height)
|
||||||
|
assert.Equal(t, "alice", header.Meta.User)
|
||||||
|
assert.Equal(t, "100.0.1.5:12345", header.Meta.RemoteAddr)
|
||||||
|
assert.Equal(t, "google|123", header.Meta.JWTUser)
|
||||||
|
assert.Equal(t, "session", header.Meta.Mode)
|
||||||
|
assert.False(t, header.Meta.Encrypted)
|
||||||
|
|
||||||
|
// Verify file is valid by checking size is reasonable
|
||||||
|
fi, err := os.Stat(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Greater(t, fi.Size(), int64(100), "recording should have content")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecorderDuplicateFrameSkip(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
logger := log.WithField("test", t.Name())
|
||||||
|
|
||||||
|
rec, err := newVNCRecorder(dir, 100, 100, &RecordingMeta{RemoteAddr: "test"}, "", logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
img := makeTestImage(100, 100, color.RGBA{128, 128, 128, 255})
|
||||||
|
|
||||||
|
rec.writeFrame(img)
|
||||||
|
rec.writeFrame(img) // duplicate
|
||||||
|
rec.writeFrame(img) // duplicate
|
||||||
|
rec.close()
|
||||||
|
|
||||||
|
files, _ := os.ReadDir(dir)
|
||||||
|
filePath := filepath.Join(dir, files[0].Name())
|
||||||
|
|
||||||
|
// Count frames by parsing
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = parseRecHeader(f)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
frameCount := 0
|
||||||
|
var hdr [8]byte
|
||||||
|
for {
|
||||||
|
if _, err := f.Read(hdr[:]); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pngLen := int64(hdr[4])<<24 | int64(hdr[5])<<16 | int64(hdr[6])<<8 | int64(hdr[7])
|
||||||
|
f.Seek(pngLen, 1)
|
||||||
|
frameCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 1, frameCount, "duplicate frames should be skipped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecorderEncrypted(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
logger := log.WithField("test", t.Name())
|
||||||
|
|
||||||
|
// Generate admin keypair
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
|
||||||
|
meta := &RecordingMeta{
|
||||||
|
RemoteAddr: "100.0.1.5:12345",
|
||||||
|
Mode: "attach",
|
||||||
|
}
|
||||||
|
|
||||||
|
rec, err := newVNCRecorder(dir, 200, 150, meta, adminPubB64, logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
img := makeTestImage(200, 150, color.RGBA{255, 0, 0, 255})
|
||||||
|
rec.writeFrame(img)
|
||||||
|
rec.close()
|
||||||
|
|
||||||
|
// Read header and verify encryption metadata
|
||||||
|
files, _ := os.ReadDir(dir)
|
||||||
|
filePath := filepath.Join(dir, files[0].Name())
|
||||||
|
|
||||||
|
header, err := ReadRecordingHeader(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, header.Meta.Encrypted)
|
||||||
|
assert.NotEmpty(t, header.Meta.EphemeralKey)
|
||||||
|
assert.Equal(t, 200, header.Width)
|
||||||
|
assert.Equal(t, 150, header.Height)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecorderEncryptedDecryptRoundtrip(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
logger := log.WithField("test", t.Name())
|
||||||
|
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||||
|
|
||||||
|
rec, err := newVNCRecorder(dir, 100, 100, &RecordingMeta{RemoteAddr: "test"}, adminPubB64, logger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
red := makeTestImage(100, 100, color.RGBA{255, 0, 0, 255})
|
||||||
|
green := makeTestImage(100, 100, color.RGBA{0, 255, 0, 255})
|
||||||
|
|
||||||
|
rec.writeFrame(red)
|
||||||
|
rec.writeFrame(green)
|
||||||
|
rec.close()
|
||||||
|
|
||||||
|
// Read back and decrypt
|
||||||
|
files, _ := os.ReadDir(dir)
|
||||||
|
filePath := filepath.Join(dir, files[0].Name())
|
||||||
|
|
||||||
|
header, err := ReadRecordingHeader(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, header.Meta.Encrypted)
|
||||||
|
|
||||||
|
dec, err := DecryptRecording(adminPrivB64, header.Meta.EphemeralKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Read raw frames and decrypt
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = parseRecHeader(f)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
decryptedFrames := 0
|
||||||
|
var hdr [8]byte
|
||||||
|
for {
|
||||||
|
if _, readErr := f.Read(hdr[:]); readErr != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
frameLen := int(hdr[4])<<24 | int(hdr[5])<<16 | int(hdr[6])<<8 | int(hdr[7])
|
||||||
|
ct := make([]byte, frameLen)
|
||||||
|
f.Read(ct)
|
||||||
|
|
||||||
|
_, err := dec.Decrypt(ct)
|
||||||
|
require.NoError(t, err, "frame %d decrypt should succeed", decryptedFrames)
|
||||||
|
decryptedFrames++
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 2, decryptedFrames)
|
||||||
|
}
|
||||||
64
client/vnc/server/replay.go
Normal file
64
client/vnc/server/replay.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecordingHeader holds parsed header data from a VNC recording file.
|
||||||
|
type RecordingHeader struct {
|
||||||
|
Width int
|
||||||
|
Height int
|
||||||
|
StartTime time.Time
|
||||||
|
Meta RecordingMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadRecordingHeader parses and returns the recording header without loading frames.
|
||||||
|
func ReadRecordingHeader(filePath string) (*RecordingHeader, error) {
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
return parseRecHeader(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRecHeader(r io.Reader) (*RecordingHeader, error) {
|
||||||
|
var hdr [22]byte
|
||||||
|
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||||
|
return nil, fmt.Errorf("read header: %w", err)
|
||||||
|
}
|
||||||
|
if string(hdr[:6]) != recMagic {
|
||||||
|
return nil, fmt.Errorf("invalid magic: %x", hdr[:6])
|
||||||
|
}
|
||||||
|
|
||||||
|
width := int(binary.BigEndian.Uint16(hdr[6:8]))
|
||||||
|
height := int(binary.BigEndian.Uint16(hdr[8:10]))
|
||||||
|
startMs := int64(binary.BigEndian.Uint64(hdr[10:18]))
|
||||||
|
metaLen := binary.BigEndian.Uint32(hdr[18:22])
|
||||||
|
|
||||||
|
if metaLen > 1<<20 {
|
||||||
|
return nil, fmt.Errorf("meta too large: %d bytes", metaLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
metaJSON := make([]byte, metaLen)
|
||||||
|
if _, err := io.ReadFull(r, metaJSON); err != nil {
|
||||||
|
return nil, fmt.Errorf("read meta: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var meta RecordingMeta
|
||||||
|
if err := json.Unmarshal(metaJSON, &meta); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse meta: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RecordingHeader{
|
||||||
|
Width: width,
|
||||||
|
Height: height,
|
||||||
|
StartTime: time.UnixMilli(startMs),
|
||||||
|
Meta: meta,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
264
client/vnc/server/rfb.go
Normal file
264
client/vnc/server/rfb.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/zlib"
|
||||||
|
"crypto/des"
|
||||||
|
"encoding/binary"
|
||||||
|
"image"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
rfbProtocolVersion = "RFB 003.008\n"
|
||||||
|
|
||||||
|
secNone = 1
|
||||||
|
secVNCAuth = 2
|
||||||
|
|
||||||
|
// Client message types.
|
||||||
|
clientSetPixelFormat = 0
|
||||||
|
clientSetEncodings = 2
|
||||||
|
clientFramebufferUpdateRequest = 3
|
||||||
|
clientKeyEvent = 4
|
||||||
|
clientPointerEvent = 5
|
||||||
|
clientCutText = 6
|
||||||
|
|
||||||
|
// Server message types.
|
||||||
|
serverFramebufferUpdate = 0
|
||||||
|
serverCutText = 3
|
||||||
|
|
||||||
|
// Encoding types.
|
||||||
|
encRaw = 0
|
||||||
|
encZlib = 6
|
||||||
|
)
|
||||||
|
|
||||||
|
// serverPixelFormat is the default pixel format advertised by the server:
|
||||||
|
// 32bpp RGBA, big-endian, true-colour, 8 bits per channel.
|
||||||
|
var serverPixelFormat = [16]byte{
|
||||||
|
32, // bits-per-pixel
|
||||||
|
24, // depth
|
||||||
|
1, // big-endian-flag
|
||||||
|
1, // true-colour-flag
|
||||||
|
0, 255, // red-max
|
||||||
|
0, 255, // green-max
|
||||||
|
0, 255, // blue-max
|
||||||
|
16, // red-shift
|
||||||
|
8, // green-shift
|
||||||
|
0, // blue-shift
|
||||||
|
0, 0, 0, // padding
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientPixelFormat holds the negotiated pixel format from the client.
|
||||||
|
type clientPixelFormat struct {
|
||||||
|
bpp uint8
|
||||||
|
bigEndian uint8
|
||||||
|
rMax uint16
|
||||||
|
gMax uint16
|
||||||
|
bMax uint16
|
||||||
|
rShift uint8
|
||||||
|
gShift uint8
|
||||||
|
bShift uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultClientPixelFormat() clientPixelFormat {
|
||||||
|
return clientPixelFormat{
|
||||||
|
bpp: serverPixelFormat[0],
|
||||||
|
bigEndian: serverPixelFormat[2],
|
||||||
|
rMax: binary.BigEndian.Uint16(serverPixelFormat[4:6]),
|
||||||
|
gMax: binary.BigEndian.Uint16(serverPixelFormat[6:8]),
|
||||||
|
bMax: binary.BigEndian.Uint16(serverPixelFormat[8:10]),
|
||||||
|
rShift: serverPixelFormat[10],
|
||||||
|
gShift: serverPixelFormat[11],
|
||||||
|
bShift: serverPixelFormat[12],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePixelFormat(pf []byte) clientPixelFormat {
|
||||||
|
return clientPixelFormat{
|
||||||
|
bpp: pf[0],
|
||||||
|
bigEndian: pf[2],
|
||||||
|
rMax: binary.BigEndian.Uint16(pf[4:6]),
|
||||||
|
gMax: binary.BigEndian.Uint16(pf[6:8]),
|
||||||
|
bMax: binary.BigEndian.Uint16(pf[8:10]),
|
||||||
|
rShift: pf[10],
|
||||||
|
gShift: pf[11],
|
||||||
|
bShift: pf[12],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
|
||||||
|
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
|
||||||
|
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
|
||||||
|
bytesPerPixel := max(int(pf.bpp)/8, 1)
|
||||||
|
|
||||||
|
pixelBytes := w * h * bytesPerPixel
|
||||||
|
buf := make([]byte, 4+12+pixelBytes)
|
||||||
|
|
||||||
|
// FramebufferUpdate header.
|
||||||
|
buf[0] = serverFramebufferUpdate
|
||||||
|
buf[1] = 0 // padding
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||||
|
|
||||||
|
// Rectangle header.
|
||||||
|
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||||
|
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||||
|
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||||
|
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||||
|
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
|
||||||
|
|
||||||
|
off := 16
|
||||||
|
stride := img.Stride
|
||||||
|
for row := y; row < y+h; row++ {
|
||||||
|
for col := x; col < x+w; col++ {
|
||||||
|
p := row*stride + col*4
|
||||||
|
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
|
||||||
|
|
||||||
|
rv := uint32(r) * uint32(pf.rMax) / 255
|
||||||
|
gv := uint32(g) * uint32(pf.gMax) / 255
|
||||||
|
bv := uint32(b) * uint32(pf.bMax) / 255
|
||||||
|
pixel := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
|
||||||
|
|
||||||
|
if pf.bigEndian != 0 {
|
||||||
|
for i := range bytesPerPixel {
|
||||||
|
buf[off+i] = byte(pixel >> uint((bytesPerPixel-1-i)*8))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range bytesPerPixel {
|
||||||
|
buf[off+i] = byte(pixel >> uint(i*8))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
off += bytesPerPixel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// vncAuthEncrypt encrypts a 16-byte challenge using the VNC DES scheme.
|
||||||
|
func vncAuthEncrypt(challenge []byte, password string) []byte {
|
||||||
|
key := make([]byte, 8)
|
||||||
|
for i, c := range []byte(password) {
|
||||||
|
if i >= 8 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
key[i] = reverseBits(c)
|
||||||
|
}
|
||||||
|
block, _ := des.NewCipher(key)
|
||||||
|
out := make([]byte, 16)
|
||||||
|
block.Encrypt(out[:8], challenge[:8])
|
||||||
|
block.Encrypt(out[8:], challenge[8:])
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseBits(b byte) byte {
|
||||||
|
var r byte
|
||||||
|
for range 8 {
|
||||||
|
r = (r << 1) | (b & 1)
|
||||||
|
b >>= 1
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeZlibRect encodes a framebuffer region using Zlib compression.
|
||||||
|
// The zlib stream is continuous for the entire VNC session: noVNC creates
|
||||||
|
// one inflate context at startup and reuses it for all zlib-encoded rects.
|
||||||
|
// We must NOT reset the zlib writer between calls.
|
||||||
|
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, zw *zlib.Writer, zbuf *bytes.Buffer) []byte {
|
||||||
|
bytesPerPixel := max(int(pf.bpp)/8, 1)
|
||||||
|
|
||||||
|
// Clear the output buffer but keep the deflate dictionary intact.
|
||||||
|
zbuf.Reset()
|
||||||
|
|
||||||
|
stride := img.Stride
|
||||||
|
pixel := make([]byte, bytesPerPixel)
|
||||||
|
for row := y; row < y+h; row++ {
|
||||||
|
for col := x; col < x+w; col++ {
|
||||||
|
p := row*stride + col*4
|
||||||
|
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
|
||||||
|
|
||||||
|
rv := uint32(r) * uint32(pf.rMax) / 255
|
||||||
|
gv := uint32(g) * uint32(pf.gMax) / 255
|
||||||
|
bv := uint32(b) * uint32(pf.bMax) / 255
|
||||||
|
val := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
|
||||||
|
|
||||||
|
if pf.bigEndian != 0 {
|
||||||
|
for i := range bytesPerPixel {
|
||||||
|
pixel[i] = byte(val >> uint((bytesPerPixel-1-i)*8))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range bytesPerPixel {
|
||||||
|
pixel[i] = byte(val >> uint(i*8))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
zw.Write(pixel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
zw.Flush()
|
||||||
|
|
||||||
|
compressed := zbuf.Bytes()
|
||||||
|
|
||||||
|
// Build the FramebufferUpdate message.
|
||||||
|
buf := make([]byte, 4+12+4+len(compressed))
|
||||||
|
buf[0] = serverFramebufferUpdate
|
||||||
|
buf[1] = 0
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], 1) // 1 rectangle
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||||
|
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||||
|
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||||
|
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||||
|
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
|
||||||
|
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
|
||||||
|
copy(buf[20:], compressed)
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// diffRects compares two RGBA images and returns a list of dirty rectangles.
|
||||||
|
// Divides the screen into tiles and checks each for changes.
|
||||||
|
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||||
|
if prev == nil {
|
||||||
|
return [][4]int{{0, 0, w, h}}
|
||||||
|
}
|
||||||
|
|
||||||
|
var rects [][4]int
|
||||||
|
for ty := 0; ty < h; ty += tileSize {
|
||||||
|
th := min(tileSize, h-ty)
|
||||||
|
for tx := 0; tx < w; tx += tileSize {
|
||||||
|
tw := min(tileSize, w-tx)
|
||||||
|
if tileChanged(prev, cur, tx, ty, tw, th) {
|
||||||
|
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rects
|
||||||
|
}
|
||||||
|
|
||||||
|
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
|
||||||
|
stride := prev.Stride
|
||||||
|
for row := y; row < y+h; row++ {
|
||||||
|
off := row*stride + x*4
|
||||||
|
end := off + w*4
|
||||||
|
prevRow := prev.Pix[off:end]
|
||||||
|
curRow := cur.Pix[off:end]
|
||||||
|
if !bytes.Equal(prevRow, curRow) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// zlibState holds the persistent zlib writer and buffer for a session.
|
||||||
|
type zlibState struct {
|
||||||
|
buf *bytes.Buffer
|
||||||
|
w *zlib.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newZlibState() *zlibState {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w, _ := zlib.NewWriterLevel(buf, zlib.BestSpeed)
|
||||||
|
return &zlibState{buf: buf, w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *zlibState) Close() error {
|
||||||
|
return z.w.Close()
|
||||||
|
}
|
||||||
690
client/vnc/server/server.go
Normal file
690
client/vnc/server/server.go
Normal file
@@ -0,0 +1,690 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gojwt "github.com/golang-jwt/jwt/v5"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connection modes sent by the client in the session header.
|
||||||
|
const (
|
||||||
|
ModeAttach byte = 0 // Capture current display
|
||||||
|
ModeSession byte = 1 // Virtual session as specified user
|
||||||
|
)
|
||||||
|
|
||||||
|
// RFB security-failure reason codes sent to the client. These prefixes are
|
||||||
|
// stable so dashboard/noVNC integrations can branch on them without parsing
|
||||||
|
// free text. Format: "CODE: human message".
|
||||||
|
const (
|
||||||
|
RejectCodeJWTMissing = "AUTH_JWT_MISSING"
|
||||||
|
RejectCodeJWTExpired = "AUTH_JWT_EXPIRED"
|
||||||
|
RejectCodeJWTInvalid = "AUTH_JWT_INVALID"
|
||||||
|
RejectCodeAuthForbidden = "AUTH_FORBIDDEN"
|
||||||
|
RejectCodeAuthConfig = "AUTH_CONFIG"
|
||||||
|
RejectCodeSessionError = "SESSION_ERROR"
|
||||||
|
RejectCodeCapturerError = "CAPTURER_ERROR"
|
||||||
|
RejectCodeUnsupportedOS = "UNSUPPORTED"
|
||||||
|
RejectCodeBadRequest = "BAD_REQUEST"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnvVNCDisableDownscale disables any platform-specific framebuffer
|
||||||
|
// downscaling (e.g. Retina 2:1). Set to 1/true to send the native resolution.
|
||||||
|
const EnvVNCDisableDownscale = "NB_VNC_DISABLE_DOWNSCALE"
|
||||||
|
|
||||||
|
// ScreenCapturer grabs desktop frames for the VNC server.
|
||||||
|
type ScreenCapturer interface {
|
||||||
|
// Width returns the current screen width in pixels.
|
||||||
|
Width() int
|
||||||
|
// Height returns the current screen height in pixels.
|
||||||
|
Height() int
|
||||||
|
// Capture returns the current desktop as an RGBA image.
|
||||||
|
Capture() (*image.RGBA, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InputInjector delivers keyboard and mouse events to the OS.
|
||||||
|
type InputInjector interface {
|
||||||
|
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||||
|
InjectKey(keysym uint32, down bool)
|
||||||
|
// InjectPointer simulates mouse movement and button state.
|
||||||
|
InjectPointer(buttonMask uint8, x, y, serverW, serverH int)
|
||||||
|
// SetClipboard sets the system clipboard to the given text.
|
||||||
|
SetClipboard(text string)
|
||||||
|
// GetClipboard returns the current system clipboard text.
|
||||||
|
GetClipboard() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTConfig holds JWT validation configuration for VNC auth.
|
||||||
|
type JWTConfig struct {
|
||||||
|
Issuer string
|
||||||
|
KeysLocation string
|
||||||
|
MaxTokenAge int64
|
||||||
|
Audiences []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectionHeader is sent by the client before the RFB handshake to specify
|
||||||
|
// the VNC session mode and authenticate.
|
||||||
|
type connectionHeader struct {
|
||||||
|
mode byte
|
||||||
|
username string
|
||||||
|
jwt string
|
||||||
|
sessionID uint32 // Windows session ID (0 = console/auto)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server is the embedded VNC server that listens on the WireGuard interface.
|
||||||
|
// It supports two operating modes:
|
||||||
|
// - Direct mode: captures the screen and handles VNC sessions in-process.
|
||||||
|
// Used when running in a user session with desktop access.
|
||||||
|
// - Service mode: proxies VNC connections to an agent process spawned in
|
||||||
|
// the active console session. Used when running as a Windows service in
|
||||||
|
// Session 0.
|
||||||
|
//
|
||||||
|
// Within direct mode, each connection can request one of two session modes
|
||||||
|
// via the connection header:
|
||||||
|
// - Attach: capture the current physical display.
|
||||||
|
// - Session: start a virtual Xvfb display as the requested user.
|
||||||
|
type Server struct {
|
||||||
|
capturer ScreenCapturer
|
||||||
|
injector InputInjector
|
||||||
|
password string
|
||||||
|
serviceMode bool
|
||||||
|
disableAuth bool
|
||||||
|
localAddr netip.Addr // NetBird WireGuard IP this server is bound to
|
||||||
|
network netip.Prefix // NetBird overlay network
|
||||||
|
log *log.Entry
|
||||||
|
|
||||||
|
recordingDir string // when set, VNC sessions are recorded to this directory
|
||||||
|
recordingEncKey string // base64-encoded X25519 public key for encrypting recordings
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
listener net.Listener
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
vmgr virtualSessionManager
|
||||||
|
jwtConfig *JWTConfig
|
||||||
|
jwtValidator *nbjwt.Validator
|
||||||
|
jwtExtractor *nbjwt.ClaimsExtractor
|
||||||
|
authorizer *sshauth.Authorizer
|
||||||
|
netstackNet *netstack.Net
|
||||||
|
agentToken []byte // raw token bytes for agent-mode auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// vncSession provides capturer and injector for a virtual display session.
|
||||||
|
type vncSession interface {
|
||||||
|
Capturer() ScreenCapturer
|
||||||
|
Injector() InputInjector
|
||||||
|
Display() string
|
||||||
|
ClientConnect()
|
||||||
|
ClientDisconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
// virtualSessionManager is implemented by sessionManager on Linux.
|
||||||
|
type virtualSessionManager interface {
|
||||||
|
GetOrCreate(username string) (vncSession, error)
|
||||||
|
StopAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a VNC server with the given screen capturer and input injector.
|
||||||
|
func New(capturer ScreenCapturer, injector InputInjector, password string) *Server {
|
||||||
|
return &Server{
|
||||||
|
capturer: capturer,
|
||||||
|
injector: injector,
|
||||||
|
password: password,
|
||||||
|
authorizer: sshauth.NewAuthorizer(),
|
||||||
|
log: log.WithField("component", "vnc-server"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
|
||||||
|
func (s *Server) SetServiceMode(enabled bool) {
|
||||||
|
s.serviceMode = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetJWTConfig configures JWT authentication for VNC connections.
|
||||||
|
// Pass nil to disable JWT (public mode).
|
||||||
|
func (s *Server) SetJWTConfig(config *JWTConfig) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.jwtConfig = config
|
||||||
|
s.jwtValidator = nil
|
||||||
|
s.jwtExtractor = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableAuth disables authentication entirely.
|
||||||
|
func (s *Server) SetDisableAuth(disable bool) {
|
||||||
|
s.disableAuth = disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAgentToken sets a hex-encoded token that must be presented by incoming
|
||||||
|
// connections before any VNC data. Used in agent mode to verify that only the
|
||||||
|
// trusted service process connects.
|
||||||
|
func (s *Server) SetAgentToken(hexToken string) {
|
||||||
|
if hexToken == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b, err := hex.DecodeString(hexToken)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Warnf("invalid agent token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.agentToken = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetstackNet sets the netstack network for userspace-only listening.
|
||||||
|
// When set, the VNC server listens via netstack instead of a real OS socket.
|
||||||
|
func (s *Server) SetNetstackNet(n *netstack.Net) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.netstackNet = n
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRecordingDir enables VNC session recording to the given directory.
|
||||||
|
func (s *Server) SetRecordingDir(dir string) {
|
||||||
|
s.recordingDir = dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRecordingEncryptionKey sets the base64-encoded X25519 public key for encrypting recordings.
|
||||||
|
func (s *Server) SetRecordingEncryptionKey(key string) {
|
||||||
|
s.recordingEncKey = key
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateVNCAuth updates the fine-grained authorization configuration.
|
||||||
|
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.jwtValidator = nil
|
||||||
|
s.jwtExtractor = nil
|
||||||
|
s.authorizer.Update(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins listening for VNC connections on the given address.
|
||||||
|
// network is the NetBird overlay prefix used to validate connection sources.
|
||||||
|
func (s *Server) Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
return fmt.Errorf("server already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||||
|
s.vmgr = s.platformSessionManager()
|
||||||
|
s.localAddr = addr.Addr()
|
||||||
|
s.network = network
|
||||||
|
|
||||||
|
var listener net.Listener
|
||||||
|
var listenDesc string
|
||||||
|
if s.netstackNet != nil {
|
||||||
|
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen on netstack %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
listener = ln
|
||||||
|
listenDesc = fmt.Sprintf("netstack %s", addr)
|
||||||
|
} else {
|
||||||
|
tcpAddr := net.TCPAddrFromAddrPort(addr)
|
||||||
|
ln, err := net.ListenTCP("tcp", tcpAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
listener = ln
|
||||||
|
listenDesc = addr.String()
|
||||||
|
}
|
||||||
|
s.listener = listener
|
||||||
|
|
||||||
|
if s.serviceMode {
|
||||||
|
s.platformInit()
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.serviceMode {
|
||||||
|
go s.serviceAcceptLoop()
|
||||||
|
} else {
|
||||||
|
go s.acceptLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Infof("started on %s (service_mode=%v)", listenDesc, s.serviceMode)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the server and closes all connections.
|
||||||
|
func (s *Server) Stop() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.cancel != nil {
|
||||||
|
s.cancel()
|
||||||
|
s.cancel = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.vmgr != nil {
|
||||||
|
s.vmgr.StopAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
if c, ok := s.capturer.(interface{ Close() }); ok {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
err := s.listener.Close()
|
||||||
|
s.listener = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("close VNC listener: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Info("stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// acceptLoop handles VNC connections directly (user session mode).
|
||||||
|
func (s *Server) acceptLoop() {
|
||||||
|
for {
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
s.log.Debugf("accept VNC connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.handleConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) validateCapturer(cap ScreenCapturer) error {
|
||||||
|
// Quick check first: if already ready, return immediately.
|
||||||
|
if cap.Width() > 0 && cap.Height() > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Capturer not ready: poke any retry loop that supports it so it doesn't
|
||||||
|
// wait out its full backoff (e.g. macOS waiting for Screen Recording).
|
||||||
|
if w, ok := cap.(interface{ Wake() }); ok {
|
||||||
|
w.Wake()
|
||||||
|
}
|
||||||
|
// Wait up to 5s for the capturer to become ready.
|
||||||
|
for range 50 {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
if cap.Width() > 0 && cap.Height() > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.New("no display available (check X11 on Linux or Screen Recording permission on macOS)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAllowedSource rejects connections from outside the NetBird overlay network
|
||||||
|
// and from the local WireGuard IP (prevents local privilege escalation).
|
||||||
|
// Matches the SSH server's connectionValidator logic.
|
||||||
|
func (s *Server) isAllowedSource(addr net.Addr) bool {
|
||||||
|
tcpAddr, ok := addr.(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
s.log.Warnf("connection rejected: non-TCP address %s", addr)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||||
|
if !ok {
|
||||||
|
s.log.Warnf("connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remoteIP = remoteIP.Unmap()
|
||||||
|
|
||||||
|
if remoteIP.IsLoopback() && s.localAddr.IsLoopback() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if remoteIP == s.localAddr {
|
||||||
|
s.log.Warnf("connection rejected from own IP %s", remoteIP)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.network.IsValid() && !s.network.Contains(remoteIP) {
|
||||||
|
s.log.Warnf("connection rejected from non-NetBird IP %s", remoteIP)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleConnection(conn net.Conn) {
|
||||||
|
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||||
|
|
||||||
|
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.agentToken) > 0 {
|
||||||
|
buf := make([]byte, len(s.agentToken))
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
connLog.Debugf("set agent token deadline: %v", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
connLog.Warnf("agent auth: read token: %v", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||||
|
if subtle.ConstantTimeCompare(buf, s.agentToken) != 1 {
|
||||||
|
connLog.Warn("agent auth: invalid token, rejecting")
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := readConnectionHeader(conn)
|
||||||
|
if err != nil {
|
||||||
|
connLog.Warnf("read connection header: %v", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.disableAuth {
|
||||||
|
if s.jwtConfig == nil {
|
||||||
|
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||||
|
connLog.Warn("auth rejected: no identity provider configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
jwtUserID, err := s.authenticateJWT(header)
|
||||||
|
if err != nil {
|
||||||
|
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||||
|
connLog.Warnf("auth rejected: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
connLog = connLog.WithField("jwt_user", jwtUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturer ScreenCapturer
|
||||||
|
var injector InputInjector
|
||||||
|
|
||||||
|
switch header.mode {
|
||||||
|
case ModeSession:
|
||||||
|
if s.vmgr == nil {
|
||||||
|
rejectConnection(conn, codeMessage(RejectCodeUnsupportedOS, "virtual sessions not supported on this platform"))
|
||||||
|
connLog.Warn("session rejected: not supported on this platform")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header.username == "" {
|
||||||
|
rejectConnection(conn, codeMessage(RejectCodeBadRequest, "session mode requires a username"))
|
||||||
|
connLog.Warn("session rejected: no username provided")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
vs, err := s.vmgr.GetOrCreate(header.username)
|
||||||
|
if err != nil {
|
||||||
|
rejectConnection(conn, codeMessage(RejectCodeSessionError, fmt.Sprintf("create virtual session: %v", err)))
|
||||||
|
connLog.Warnf("create virtual session for %s: %v", header.username, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
capturer = vs.Capturer()
|
||||||
|
injector = vs.Injector()
|
||||||
|
vs.ClientConnect()
|
||||||
|
defer vs.ClientDisconnect()
|
||||||
|
connLog = connLog.WithField("vnc_user", header.username)
|
||||||
|
connLog.Infof("session mode: user=%s display=%s", header.username, vs.Display())
|
||||||
|
|
||||||
|
default:
|
||||||
|
capturer = s.capturer
|
||||||
|
injector = s.injector
|
||||||
|
if cc, ok := capturer.(interface{ ClientConnect() }); ok {
|
||||||
|
cc.ClientConnect()
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cd, ok := capturer.(interface{ ClientDisconnect() }); ok {
|
||||||
|
cd.ClientDisconnect()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.validateCapturer(capturer); err != nil {
|
||||||
|
rejectConnection(conn, codeMessage(RejectCodeCapturerError, fmt.Sprintf("screen capturer: %v", err)))
|
||||||
|
connLog.Warnf("capturer not ready: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var rec *vncRecorder
|
||||||
|
if s.recordingDir != "" {
|
||||||
|
mode := "attach"
|
||||||
|
if header.mode == ModeSession {
|
||||||
|
mode = "session"
|
||||||
|
}
|
||||||
|
jwtUser, _ := connLog.Data["jwt_user"].(string)
|
||||||
|
var err error
|
||||||
|
rec, err = newVNCRecorder(s.recordingDir, capturer.Width(), capturer.Height(), &RecordingMeta{
|
||||||
|
User: header.username,
|
||||||
|
RemoteAddr: conn.RemoteAddr().String(),
|
||||||
|
JWTUser: jwtUser,
|
||||||
|
Mode: mode,
|
||||||
|
}, s.recordingEncKey, connLog)
|
||||||
|
if err != nil {
|
||||||
|
connLog.Warnf("start VNC recording: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sess := &session{
|
||||||
|
conn: conn,
|
||||||
|
capturer: capturer,
|
||||||
|
injector: injector,
|
||||||
|
serverW: capturer.Width(),
|
||||||
|
serverH: capturer.Height(),
|
||||||
|
password: s.password,
|
||||||
|
log: connLog,
|
||||||
|
recorder: rec,
|
||||||
|
}
|
||||||
|
sess.serve()
|
||||||
|
}
|
||||||
|
|
||||||
|
// codeMessage formats a stable reject code with a human-readable message.
|
||||||
|
// Dashboards split on the first ": " to recover the code without parsing the
|
||||||
|
// free-text suffix.
|
||||||
|
func codeMessage(code, msg string) string {
|
||||||
|
return code + ": " + msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtErrorCode maps a JWT auth error to a stable reject code.
|
||||||
|
func jwtErrorCode(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return RejectCodeJWTInvalid
|
||||||
|
}
|
||||||
|
if errors.Is(err, nbjwt.ErrTokenExpired) {
|
||||||
|
return RejectCodeJWTExpired
|
||||||
|
}
|
||||||
|
msg := err.Error()
|
||||||
|
switch {
|
||||||
|
case strings.Contains(msg, "JWT required but not provided"):
|
||||||
|
return RejectCodeJWTMissing
|
||||||
|
case strings.Contains(msg, "authorize") || strings.Contains(msg, "not authorized"):
|
||||||
|
return RejectCodeAuthForbidden
|
||||||
|
default:
|
||||||
|
return RejectCodeJWTInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rejectConnection sends a minimal RFB handshake with a security failure
|
||||||
|
// reason, so VNC clients display the error message instead of a generic
|
||||||
|
// "unexpected disconnect."
|
||||||
|
func rejectConnection(conn net.Conn, reason string) {
|
||||||
|
defer conn.Close()
|
||||||
|
// RFB 3.8 server version.
|
||||||
|
io.WriteString(conn, "RFB 003.008\n")
|
||||||
|
// Read client version (12 bytes), ignore errors.
|
||||||
|
var clientVer [12]byte
|
||||||
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
io.ReadFull(conn, clientVer[:])
|
||||||
|
conn.SetReadDeadline(time.Time{})
|
||||||
|
// Send 0 security types = connection failed, followed by reason.
|
||||||
|
msg := []byte(reason)
|
||||||
|
buf := make([]byte, 1+4+len(msg))
|
||||||
|
buf[0] = 0 // 0 security types = failure
|
||||||
|
binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
|
||||||
|
copy(buf[5:], msg)
|
||||||
|
conn.Write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
|
||||||
|
|
||||||
|
// authenticateJWT validates the JWT from the connection header and checks
|
||||||
|
// authorization. For attach mode, just checks membership in the authorized
|
||||||
|
// user list. For session mode, additionally validates the OS user mapping.
|
||||||
|
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
|
||||||
|
if header.jwt == "" {
|
||||||
|
return "", fmt.Errorf("JWT required but not provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
if err := s.ensureJWTValidator(); err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return "", fmt.Errorf("initialize JWT validator: %w", err)
|
||||||
|
}
|
||||||
|
validator := s.jwtValidator
|
||||||
|
extractor := s.jwtExtractor
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("validate JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.checkTokenAge(token); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := extractor.ToUserAuth(token)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("extract user from JWT: %w", err)
|
||||||
|
}
|
||||||
|
if userAuth.UserId == "" {
|
||||||
|
return "", fmt.Errorf("JWT has no user ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch header.mode {
|
||||||
|
case ModeSession:
|
||||||
|
// Session mode: check user + OS username mapping.
|
||||||
|
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
|
||||||
|
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Attach mode: just check user is in the authorized list (wildcard OS user).
|
||||||
|
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
|
||||||
|
return "", fmt.Errorf("user not authorized for VNC: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return userAuth.UserId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
|
||||||
|
func (s *Server) ensureJWTValidator() error {
|
||||||
|
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.jwtConfig == nil {
|
||||||
|
return fmt.Errorf("no JWT config")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.jwtValidator = nbjwt.NewValidator(
|
||||||
|
s.jwtConfig.Issuer,
|
||||||
|
s.jwtConfig.Audiences,
|
||||||
|
s.jwtConfig.KeysLocation,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
opts := []nbjwt.ClaimsExtractorOption{nbjwt.WithAudience(s.jwtConfig.Audiences[0])}
|
||||||
|
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
|
||||||
|
opts = append(opts, nbjwt.WithUserIDClaim(claim))
|
||||||
|
}
|
||||||
|
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) checkTokenAge(token *gojwt.Token) error {
|
||||||
|
maxAge := defaultJWTMaxTokenAge
|
||||||
|
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
|
||||||
|
maxAge = int(s.jwtConfig.MaxTokenAge)
|
||||||
|
}
|
||||||
|
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readConnectionHeader reads the NetBird VNC session header from the connection.
|
||||||
|
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
|
||||||
|
//
|
||||||
|
// [jwt_len: 2 bytes BE] [jwt: N bytes]
|
||||||
|
//
|
||||||
|
// Uses a short timeout: our WASM proxy sends the header immediately after
|
||||||
|
// connecting. Standard VNC clients don't send anything first (server speaks
|
||||||
|
// first in RFB), so they time out and get the default attach mode.
|
||||||
|
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
|
return nil, fmt.Errorf("set deadline: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||||
|
|
||||||
|
var hdr [3]byte
|
||||||
|
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||||
|
// Timeout or error: assume no header, use attach mode.
|
||||||
|
return &connectionHeader{mode: ModeAttach}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore a longer deadline for reading variable-length fields.
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
return nil, fmt.Errorf("set deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := hdr[0]
|
||||||
|
usernameLen := binary.BigEndian.Uint16(hdr[1:3])
|
||||||
|
|
||||||
|
var username string
|
||||||
|
if usernameLen > 0 {
|
||||||
|
if usernameLen > 256 {
|
||||||
|
return nil, fmt.Errorf("username too long: %d", usernameLen)
|
||||||
|
}
|
||||||
|
buf := make([]byte, usernameLen)
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
return nil, fmt.Errorf("read username: %w", err)
|
||||||
|
}
|
||||||
|
username = string(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read JWT token length and data.
|
||||||
|
var jwtLenBuf [2]byte
|
||||||
|
var jwtToken string
|
||||||
|
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
|
||||||
|
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
|
||||||
|
if jwtLen > 0 && jwtLen < 8192 {
|
||||||
|
buf := make([]byte, jwtLen)
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
return nil, fmt.Errorf("read JWT: %w", err)
|
||||||
|
}
|
||||||
|
jwtToken = string(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
|
||||||
|
var sessionID uint32
|
||||||
|
var sidBuf [4]byte
|
||||||
|
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
|
||||||
|
sessionID = binary.BigEndian.Uint32(sidBuf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return &connectionHeader{mode: mode, username: username, jwt: jwtToken, sessionID: sessionID}, nil
|
||||||
|
}
|
||||||
15
client/vnc/server/server_darwin.go
Normal file
15
client/vnc/server/server_darwin.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
func (s *Server) platformInit() {}
|
||||||
|
|
||||||
|
// serviceAcceptLoop is not supported on macOS.
|
||||||
|
func (s *Server) serviceAcceptLoop() {
|
||||||
|
s.log.Warn("service mode not supported on macOS, falling back to direct mode")
|
||||||
|
s.acceptLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
15
client/vnc/server/server_stub.go
Normal file
15
client/vnc/server/server_stub.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
func (s *Server) platformInit() {}
|
||||||
|
|
||||||
|
// serviceAcceptLoop is not supported on non-Windows platforms.
|
||||||
|
func (s *Server) serviceAcceptLoop() {
|
||||||
|
s.log.Warn("service mode not supported on this platform, falling back to direct mode")
|
||||||
|
s.acceptLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
136
client/vnc/server/server_test.go
Normal file
136
client/vnc/server/server_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"image"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testCapturer returns a 100x100 image for test sessions.
|
||||||
|
type testCapturer struct{}
|
||||||
|
|
||||||
|
func (t *testCapturer) Width() int { return 100 }
|
||||||
|
func (t *testCapturer) Height() int { return 100 }
|
||||||
|
func (t *testCapturer) Capture() (*image.RGBA, error) { return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil }
|
||||||
|
|
||||||
|
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
||||||
|
srv.SetDisableAuth(disableAuth)
|
||||||
|
if jwtConfig != nil {
|
||||||
|
srv.SetJWTConfig(jwtConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||||
|
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||||
|
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||||
|
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
||||||
|
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||||
|
t.Cleanup(func() { _ = srv.Stop() })
|
||||||
|
|
||||||
|
return srv.listener.Addr(), srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
|
||||||
|
addr, _ := startTestServer(t, false, nil)
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", addr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Send session header: attach mode, no username, no JWT.
|
||||||
|
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||||
|
_, err = conn.Write(header)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Server should send RFB version then security failure.
|
||||||
|
var version [12]byte
|
||||||
|
_, err = io.ReadFull(conn, version[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||||
|
|
||||||
|
// Write client version to proceed through handshake.
|
||||||
|
_, err = conn.Write(version[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Read security types: 0 means failure, followed by reason.
|
||||||
|
var numTypes [1]byte
|
||||||
|
_, err = io.ReadFull(conn, numTypes[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
||||||
|
|
||||||
|
var reasonLen [4]byte
|
||||||
|
_, err = io.ReadFull(conn, reasonLen[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||||
|
_, err = io.ReadFull(conn, reason)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||||
|
addr, _ := startTestServer(t, true, nil)
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", addr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Send session header: attach mode, no username, no JWT.
|
||||||
|
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||||
|
_, err = conn.Write(header)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Server should send RFB version.
|
||||||
|
var version [12]byte
|
||||||
|
_, err = io.ReadFull(conn, version[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||||
|
|
||||||
|
// Write client version.
|
||||||
|
_, err = conn.Write(version[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should get security types (not 0 = failure).
|
||||||
|
var numTypes [1]byte
|
||||||
|
_, err = io.ReadFull(conn, numTypes[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
|
||||||
|
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
|
||||||
|
addr, _ := startTestServer(t, false, &JWTConfig{
|
||||||
|
Issuer: "https://example.com",
|
||||||
|
KeysLocation: "https://example.com/.well-known/jwks.json",
|
||||||
|
Audiences: []string{"test"},
|
||||||
|
})
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", addr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Send session header with empty JWT.
|
||||||
|
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||||
|
_, err = conn.Write(header)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var version [12]byte
|
||||||
|
_, err = io.ReadFull(conn, version[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = conn.Write(version[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var numTypes [1]byte
|
||||||
|
_, err = io.ReadFull(conn, numTypes[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
|
||||||
|
}
|
||||||
222
client/vnc/server/server_windows.go
Normal file
222
client/vnc/server/server_windows.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sasDLL = windows.NewLazySystemDLL("sas.dll")
|
||||||
|
procSendSAS = sasDLL.NewProc("SendSAS")
|
||||||
|
|
||||||
|
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||||
|
)
|
||||||
|
|
||||||
|
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
|
||||||
|
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
|
||||||
|
// local processes from triggering the Secure Attention Sequence.
|
||||||
|
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
|
||||||
|
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
|
||||||
|
// to the interactive user (IU) so the VNC agent in the console session can
|
||||||
|
// signal it. Other local users and network users are denied.
|
||||||
|
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var sd uintptr
|
||||||
|
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
|
||||||
|
uintptr(unsafe.Pointer(sddl)),
|
||||||
|
1, // SDDL_REVISION_1
|
||||||
|
uintptr(unsafe.Pointer(&sd)),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if r == 0 {
|
||||||
|
return nil, lerr
|
||||||
|
}
|
||||||
|
return &windows.SecurityAttributes{
|
||||||
|
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||||
|
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
|
||||||
|
InheritHandle: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
|
||||||
|
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
|
||||||
|
// SendSAS silently does nothing on most Windows editions.
|
||||||
|
func enableSoftwareSAS() {
|
||||||
|
key, _, err := registry.CreateKey(
|
||||||
|
registry.LOCAL_MACHINE,
|
||||||
|
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||||
|
registry.SET_VALUE,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer key.Close()
|
||||||
|
|
||||||
|
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
|
||||||
|
log.Warnf("set SoftwareSASGeneration: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// startSASListener creates a named event with a restricted DACL and waits for
|
||||||
|
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
|
||||||
|
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
|
||||||
|
// Only SYSTEM processes can open the event.
|
||||||
|
func startSASListener() {
|
||||||
|
enableSoftwareSAS()
|
||||||
|
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("SAS listener UTF16: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sa, err := sasSecurityAttributes()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("build SAS security descriptor: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("SAS CreateEvent: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info("SAS listener ready (Session 0)")
|
||||||
|
go func() {
|
||||||
|
defer windows.CloseHandle(ev)
|
||||||
|
for {
|
||||||
|
ret, _ := windows.WaitForSingleObject(ev, windows.INFINITE)
|
||||||
|
if ret == windows.WAIT_OBJECT_0 {
|
||||||
|
r, _, sasErr := procSendSAS.Call(0) // FALSE = not from service desktop
|
||||||
|
if r == 0 {
|
||||||
|
log.Warnf("SendSAS: %v", sasErr)
|
||||||
|
} else {
|
||||||
|
log.Info("SendSAS called from Session 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// enablePrivilege enables a named privilege on the current process token.
|
||||||
|
func enablePrivilege(name string) error {
|
||||||
|
var token windows.Token
|
||||||
|
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||||
|
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer token.Close()
|
||||||
|
|
||||||
|
var luid windows.LUID
|
||||||
|
namePtr, _ := windows.UTF16PtrFromString(name)
|
||||||
|
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tp := windows.Tokenprivileges{PrivilegeCount: 1}
|
||||||
|
tp.Privileges[0].Luid = luid
|
||||||
|
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||||
|
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// platformInit starts the SAS listener and enables privileges needed for
|
||||||
|
// Session 0 operations (agent spawning, SendSAS).
|
||||||
|
func (s *Server) platformInit() {
|
||||||
|
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
|
||||||
|
if err := enablePrivilege(priv); err != nil {
|
||||||
|
log.Debugf("enable %s: %v", priv, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
startSASListener()
|
||||||
|
}
|
||||||
|
|
||||||
|
// serviceAcceptLoop runs in Session 0. It validates source IP and
|
||||||
|
// authenticates via JWT before proxying connections to the user-session agent.
|
||||||
|
func (s *Server) serviceAcceptLoop() {
|
||||||
|
|
||||||
|
sm := newSessionManager(agentPort)
|
||||||
|
go sm.run()
|
||||||
|
|
||||||
|
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%s", agentPort)
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
sm.Stop()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
s.log.Debugf("accept VNC connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.handleServiceConnection(conn, sm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleServiceConnection validates the source IP and JWT, then proxies
|
||||||
|
// the connection (with header bytes replayed) to the agent.
|
||||||
|
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||||
|
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||||
|
|
||||||
|
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var headerBuf bytes.Buffer
|
||||||
|
tee := io.TeeReader(conn, &headerBuf)
|
||||||
|
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||||
|
|
||||||
|
header, err := readConnectionHeader(teeConn)
|
||||||
|
if err != nil {
|
||||||
|
connLog.Debugf("read connection header: %v", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.disableAuth {
|
||||||
|
if s.jwtConfig == nil {
|
||||||
|
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||||
|
connLog.Warn("auth rejected: no identity provider configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := s.authenticateJWT(header); err != nil {
|
||||||
|
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||||
|
connLog.Warnf("auth rejected: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replay buffered header bytes + remaining stream to the agent.
|
||||||
|
replayConn := &prefixConn{
|
||||||
|
Reader: io.MultiReader(&headerBuf, conn),
|
||||||
|
Conn: conn,
|
||||||
|
}
|
||||||
|
proxyToAgent(replayConn, agentPort, sm.AuthToken())
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
|
||||||
|
type prefixConn struct {
|
||||||
|
io.Reader
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *prefixConn) Read(b []byte) (int, error) {
|
||||||
|
return p.Reader.Read(b)
|
||||||
|
}
|
||||||
15
client/vnc/server/server_x11.go
Normal file
15
client/vnc/server/server_x11.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
func (s *Server) platformInit() {}
|
||||||
|
|
||||||
|
// serviceAcceptLoop is not supported on Linux.
|
||||||
|
func (s *Server) serviceAcceptLoop() {
|
||||||
|
s.log.Warn("service mode not supported on Linux, falling back to direct mode")
|
||||||
|
s.acceptLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||||
|
return newSessionManager(s.log)
|
||||||
|
}
|
||||||
451
client/vnc/server/session.go
Normal file
451
client/vnc/server/session.go
Normal file
@@ -0,0 +1,451 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
readDeadline = 60 * time.Second
|
||||||
|
maxCutTextBytes = 1 << 20 // 1 MiB
|
||||||
|
)
|
||||||
|
|
||||||
|
const tileSize = 64 // pixels per tile for dirty-rect detection
|
||||||
|
|
||||||
|
type session struct {
|
||||||
|
conn net.Conn
|
||||||
|
capturer ScreenCapturer
|
||||||
|
injector InputInjector
|
||||||
|
serverW int
|
||||||
|
serverH int
|
||||||
|
password string
|
||||||
|
log *log.Entry
|
||||||
|
recorder *vncRecorder
|
||||||
|
|
||||||
|
writeMu sync.Mutex
|
||||||
|
pf clientPixelFormat
|
||||||
|
useZlib bool
|
||||||
|
zlib *zlibState
|
||||||
|
prevFrame *image.RGBA
|
||||||
|
idleFrames int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
|
||||||
|
|
||||||
|
// serve runs the full RFB session lifecycle.
|
||||||
|
func (s *session) serve() {
|
||||||
|
defer s.conn.Close()
|
||||||
|
if s.recorder != nil {
|
||||||
|
defer s.recorder.close()
|
||||||
|
}
|
||||||
|
s.pf = defaultClientPixelFormat()
|
||||||
|
|
||||||
|
if err := s.handshake(); err != nil {
|
||||||
|
s.log.Warnf("handshake with %s: %v", s.addr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.log.Infof("client connected: %s", s.addr())
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
go s.clipboardPoll(done)
|
||||||
|
|
||||||
|
if err := s.messageLoop(); err != nil && err != io.EOF {
|
||||||
|
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
|
||||||
|
} else {
|
||||||
|
s.log.Infof("client disconnected: %s", s.addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clipboardPoll periodically checks the server-side clipboard and sends
|
||||||
|
// changes to the VNC client. Only runs during active sessions.
|
||||||
|
func (s *session) clipboardPoll(done <-chan struct{}) {
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastClip string
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
text := s.injector.GetClipboard()
|
||||||
|
if len(text) > maxCutTextBytes {
|
||||||
|
text = text[:maxCutTextBytes]
|
||||||
|
}
|
||||||
|
if text != "" && text != lastClip {
|
||||||
|
lastClip = text
|
||||||
|
if err := s.sendServerCutText(text); err != nil {
|
||||||
|
s.log.Debugf("send clipboard to client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handshake() error {
|
||||||
|
// Send protocol version.
|
||||||
|
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
|
||||||
|
return fmt.Errorf("send version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read client version.
|
||||||
|
var clientVer [12]byte
|
||||||
|
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
|
||||||
|
return fmt.Errorf("read client version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send supported security types.
|
||||||
|
if err := s.sendSecurityTypes(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read chosen security type.
|
||||||
|
var secType [1]byte
|
||||||
|
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
|
||||||
|
return fmt.Errorf("read security type: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.handleSecurity(secType[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read ClientInit.
|
||||||
|
var clientInit [1]byte
|
||||||
|
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
|
||||||
|
return fmt.Errorf("read ClientInit: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.sendServerInit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) sendSecurityTypes() error {
|
||||||
|
if s.password == "" {
|
||||||
|
_, err := s.conn.Write([]byte{1, secNone})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := s.conn.Write([]byte{1, secVNCAuth})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handleSecurity(secType byte) error {
|
||||||
|
switch secType {
|
||||||
|
case secVNCAuth:
|
||||||
|
return s.doVNCAuth()
|
||||||
|
case secNone:
|
||||||
|
return binary.Write(s.conn, binary.BigEndian, uint32(0))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported security type: %d", secType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) doVNCAuth() error {
|
||||||
|
challenge := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(challenge); err != nil {
|
||||||
|
return fmt.Errorf("generate challenge: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := s.conn.Write(challenge); err != nil {
|
||||||
|
return fmt.Errorf("send challenge: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, 16)
|
||||||
|
if _, err := io.ReadFull(s.conn, response); err != nil {
|
||||||
|
return fmt.Errorf("read auth response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result uint32
|
||||||
|
if s.password != "" {
|
||||||
|
expected := vncAuthEncrypt(challenge, s.password)
|
||||||
|
if !bytes.Equal(expected, response) {
|
||||||
|
result = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(s.conn, binary.BigEndian, result); err != nil {
|
||||||
|
return fmt.Errorf("send auth result: %w", err)
|
||||||
|
}
|
||||||
|
if result != 0 {
|
||||||
|
msg := "authentication failed"
|
||||||
|
_ = binary.Write(s.conn, binary.BigEndian, uint32(len(msg)))
|
||||||
|
_, _ = s.conn.Write([]byte(msg))
|
||||||
|
return fmt.Errorf("authentication failed from %s", s.addr())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) sendServerInit() error {
|
||||||
|
name := []byte("NetBird VNC")
|
||||||
|
buf := make([]byte, 0, 4+16+4+len(name))
|
||||||
|
|
||||||
|
// Framebuffer width and height.
|
||||||
|
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
|
||||||
|
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
|
||||||
|
|
||||||
|
// Server pixel format.
|
||||||
|
buf = append(buf, serverPixelFormat[:]...)
|
||||||
|
|
||||||
|
// Desktop name.
|
||||||
|
buf = append(buf,
|
||||||
|
byte(len(name)>>24), byte(len(name)>>16),
|
||||||
|
byte(len(name)>>8), byte(len(name)),
|
||||||
|
)
|
||||||
|
buf = append(buf, name...)
|
||||||
|
|
||||||
|
_, err := s.conn.Write(buf)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) messageLoop() error {
|
||||||
|
for {
|
||||||
|
var msgType [1]byte
|
||||||
|
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
|
||||||
|
return fmt.Errorf("set deadline: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = s.conn.SetDeadline(time.Time{})
|
||||||
|
|
||||||
|
switch msgType[0] {
|
||||||
|
case clientSetPixelFormat:
|
||||||
|
if err := s.handleSetPixelFormat(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case clientSetEncodings:
|
||||||
|
if err := s.handleSetEncodings(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case clientFramebufferUpdateRequest:
|
||||||
|
if err := s.handleFBUpdateRequest(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case clientKeyEvent:
|
||||||
|
if err := s.handleKeyEvent(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case clientPointerEvent:
|
||||||
|
if err := s.handlePointerEvent(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case clientCutText:
|
||||||
|
if err := s.handleCutText(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown client message type: %d", msgType[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handleSetPixelFormat() error {
|
||||||
|
var buf [19]byte // 3 padding + 16 pixel format
|
||||||
|
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
|
||||||
|
return fmt.Errorf("read SetPixelFormat: %w", err)
|
||||||
|
}
|
||||||
|
s.pf = parsePixelFormat(buf[3:19])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handleSetEncodings() error {
|
||||||
|
var header [3]byte // 1 padding + 2 number-of-encodings
|
||||||
|
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||||
|
return fmt.Errorf("read SetEncodings header: %w", err)
|
||||||
|
}
|
||||||
|
numEnc := binary.BigEndian.Uint16(header[1:3])
|
||||||
|
buf := make([]byte, int(numEnc)*4)
|
||||||
|
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if client supports zlib encoding.
|
||||||
|
for i := range int(numEnc) {
|
||||||
|
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
|
||||||
|
if enc == encZlib {
|
||||||
|
s.useZlib = true
|
||||||
|
if s.zlib == nil {
|
||||||
|
s.zlib = newZlibState()
|
||||||
|
}
|
||||||
|
s.log.Debugf("client supports zlib encoding")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handleFBUpdateRequest() error {
|
||||||
|
var req [9]byte
|
||||||
|
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
|
||||||
|
return fmt.Errorf("read FBUpdateRequest: %w", err)
|
||||||
|
}
|
||||||
|
incremental := req[0]
|
||||||
|
|
||||||
|
img, err := s.capturer.Capture()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("capture screen: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.recorder != nil {
|
||||||
|
s.recorder.writeFrame(img)
|
||||||
|
}
|
||||||
|
|
||||||
|
if incremental == 1 && s.prevFrame != nil {
|
||||||
|
rects := diffRects(s.prevFrame, img, s.serverW, s.serverH, tileSize)
|
||||||
|
if len(rects) == 0 {
|
||||||
|
// Nothing changed. Back off briefly before responding to reduce
|
||||||
|
// CPU usage when the screen is static. The client re-requests
|
||||||
|
// immediately after receiving our empty response, so without
|
||||||
|
// this delay we'd spin at ~1000fps checking for changes.
|
||||||
|
s.idleFrames++
|
||||||
|
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
|
||||||
|
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||||
|
s.savePrevFrame(img)
|
||||||
|
return s.sendEmptyUpdate()
|
||||||
|
}
|
||||||
|
s.idleFrames = 0
|
||||||
|
s.savePrevFrame(img)
|
||||||
|
return s.sendDirtyRects(img, rects)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full update.
|
||||||
|
s.idleFrames = 0
|
||||||
|
s.savePrevFrame(img)
|
||||||
|
return s.sendFullUpdate(img)
|
||||||
|
}
|
||||||
|
|
||||||
|
// savePrevFrame copies img's pixel data into prevFrame. This is necessary
|
||||||
|
// because some capturers (DXGI) reuse the same image buffer across calls,
|
||||||
|
// so a simple pointer assignment would make prevFrame alias the live buffer
|
||||||
|
// and diffRects would always see zero changes.
|
||||||
|
func (s *session) savePrevFrame(img *image.RGBA) {
|
||||||
|
if s.prevFrame == nil || s.prevFrame.Rect != img.Rect {
|
||||||
|
s.prevFrame = image.NewRGBA(img.Rect)
|
||||||
|
}
|
||||||
|
copy(s.prevFrame.Pix, img.Pix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendEmptyUpdate sends a FramebufferUpdate with zero rectangles.
|
||||||
|
func (s *session) sendEmptyUpdate() error {
|
||||||
|
var buf [4]byte
|
||||||
|
buf[0] = serverFramebufferUpdate
|
||||||
|
s.writeMu.Lock()
|
||||||
|
_, err := s.conn.Write(buf[:])
|
||||||
|
s.writeMu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) sendFullUpdate(img *image.RGBA) error {
|
||||||
|
w, h := s.serverW, s.serverH
|
||||||
|
|
||||||
|
var buf []byte
|
||||||
|
if s.useZlib && s.zlib != nil {
|
||||||
|
buf = encodeZlibRect(img, s.pf, 0, 0, w, h, s.zlib.w, s.zlib.buf)
|
||||||
|
} else {
|
||||||
|
buf = encodeRawRect(img, s.pf, 0, 0, w, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeMu.Lock()
|
||||||
|
_, err := s.conn.Write(buf)
|
||||||
|
s.writeMu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) sendDirtyRects(img *image.RGBA, rects [][4]int) error {
|
||||||
|
// Build a multi-rectangle FramebufferUpdate.
|
||||||
|
// Header: type(1) + padding(1) + numRects(2)
|
||||||
|
header := make([]byte, 4)
|
||||||
|
header[0] = serverFramebufferUpdate
|
||||||
|
binary.BigEndian.PutUint16(header[2:4], uint16(len(rects)))
|
||||||
|
|
||||||
|
s.writeMu.Lock()
|
||||||
|
defer s.writeMu.Unlock()
|
||||||
|
|
||||||
|
if _, err := s.conn.Write(header); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rects {
|
||||||
|
x, y, w, h := r[0], r[1], r[2], r[3]
|
||||||
|
|
||||||
|
var rectBuf []byte
|
||||||
|
if s.useZlib && s.zlib != nil {
|
||||||
|
rectBuf = encodeZlibRect(img, s.pf, x, y, w, h, s.zlib.w, s.zlib.buf)
|
||||||
|
// encodeZlibRect includes its own FBUpdate header for 1 rect.
|
||||||
|
// For multi-rect, we need just the rect data without the FBUpdate header.
|
||||||
|
// Skip the 4-byte FBUpdate header since we already sent ours.
|
||||||
|
rectBuf = rectBuf[4:]
|
||||||
|
} else {
|
||||||
|
rectBuf = encodeRawRect(img, s.pf, x, y, w, h)
|
||||||
|
rectBuf = rectBuf[4:] // skip FBUpdate header
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.conn.Write(rectBuf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handleKeyEvent() error {
|
||||||
|
var data [7]byte
|
||||||
|
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||||
|
return fmt.Errorf("read KeyEvent: %w", err)
|
||||||
|
}
|
||||||
|
down := data[0] == 1
|
||||||
|
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||||
|
s.injector.InjectKey(keysym, down)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handlePointerEvent() error {
|
||||||
|
var data [5]byte
|
||||||
|
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||||
|
return fmt.Errorf("read PointerEvent: %w", err)
|
||||||
|
}
|
||||||
|
buttonMask := data[0]
|
||||||
|
x := int(binary.BigEndian.Uint16(data[1:3]))
|
||||||
|
y := int(binary.BigEndian.Uint16(data[3:5]))
|
||||||
|
s.injector.InjectPointer(buttonMask, x, y, s.serverW, s.serverH)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) handleCutText() error {
|
||||||
|
var header [7]byte // 3 padding + 4 length
|
||||||
|
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||||
|
return fmt.Errorf("read CutText header: %w", err)
|
||||||
|
}
|
||||||
|
length := binary.BigEndian.Uint32(header[3:7])
|
||||||
|
if length > maxCutTextBytes {
|
||||||
|
return fmt.Errorf("cut text too large: %d bytes", length)
|
||||||
|
}
|
||||||
|
buf := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||||
|
return fmt.Errorf("read CutText payload: %w", err)
|
||||||
|
}
|
||||||
|
s.injector.SetClipboard(string(buf))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendServerCutText sends clipboard text from the server to the client.
|
||||||
|
func (s *session) sendServerCutText(text string) error {
|
||||||
|
data := []byte(text)
|
||||||
|
buf := make([]byte, 8+len(data))
|
||||||
|
buf[0] = serverCutText
|
||||||
|
// buf[1:4] = padding (zero)
|
||||||
|
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
|
||||||
|
copy(buf[8:], data)
|
||||||
|
|
||||||
|
s.writeMu.Lock()
|
||||||
|
_, err := s.conn.Write(buf)
|
||||||
|
s.writeMu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user