mirror of
https://github.com/fosrl/olm.git
synced 2026-02-12 07:56:44 +00:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e84f802ed | ||
|
|
f40b0ff820 | ||
|
|
95a4840374 | ||
|
|
27424170e4 | ||
|
|
a8ace6f64a | ||
|
|
3fa1073f49 | ||
|
|
76d86c10ff | ||
|
|
2d34c6c8b2 | ||
|
|
a7f3477bdd | ||
|
|
af0a72d296 | ||
|
|
d1e836e760 | ||
|
|
8dd45c4ca2 | ||
|
|
9db009058b | ||
|
|
29c01deb05 | ||
|
|
7224d9824d | ||
|
|
8afc28fdff | ||
|
|
4ba2fb7b53 | ||
|
|
2e6076923d | ||
|
|
4c001dc751 | ||
|
|
2b8e240752 | ||
|
|
bee490713d | ||
|
|
1cb7fd94ab | ||
|
|
dc9a547950 | ||
|
|
2be0933246 | ||
|
|
c0b1cd6bde | ||
|
|
dd00289f8e | ||
|
|
b23a02ee97 | ||
|
|
80f726cfea | ||
|
|
aa8828186f | ||
|
|
0a990d196d | ||
|
|
18ee4c93fb |
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
47
.github/DISCUSSION_TEMPLATE/feature-requests.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Summary
|
||||||
|
description: A clear and concise summary of the requested feature.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Motivation
|
||||||
|
description: |
|
||||||
|
Why is this feature important?
|
||||||
|
Explain the problem this feature would solve or what use case it would enable.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Proposed Solution
|
||||||
|
description: |
|
||||||
|
How would you like to see this feature implemented?
|
||||||
|
Provide as much detail as possible about the desired behavior, configuration, or changes.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Alternatives Considered
|
||||||
|
description: Describe any alternative solutions or workarounds you've thought about.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Additional Context
|
||||||
|
description: Add any other context, mockups, or screenshots about the feature request here.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Before submitting, please:
|
||||||
|
- Check if there is an existing issue for this feature.
|
||||||
|
- Clearly explain the benefit and use case.
|
||||||
|
- Be as specific as possible to help contributors evaluate and implement.
|
||||||
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/1.bug_report.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Create a bug report
|
||||||
|
labels: []
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Describe the Bug
|
||||||
|
description: A clear and concise description of what the bug is.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Environment
|
||||||
|
description: Please fill out the relevant details below for your environment.
|
||||||
|
value: |
|
||||||
|
- OS Type & Version: (e.g., Ubuntu 22.04)
|
||||||
|
- Pangolin Version:
|
||||||
|
- Gerbil Version:
|
||||||
|
- Traefik Version:
|
||||||
|
- Newt Version:
|
||||||
|
- Olm Version: (if applicable)
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: To Reproduce
|
||||||
|
description: |
|
||||||
|
Steps to reproduce the behavior, please provide a clear description of how to reproduce the issue, based on the linked minimal reproduction. Screenshots can be provided in the issue body below.
|
||||||
|
|
||||||
|
If using code blocks, make sure syntax highlighting is correct and double-check that the rendered preview is not broken.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Expected Behavior
|
||||||
|
description: A clear and concise description of what you expected to happen.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Before posting the issue go through the steps you've written down to make sure the steps provided are detailed and clear.
|
||||||
|
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Contributors should be able to follow the steps provided in order to reproduce the bug.
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: false
|
||||||
|
contact_links:
|
||||||
|
- name: Need help or have questions?
|
||||||
|
url: https://github.com/orgs/fosrl/discussions
|
||||||
|
about: Ask questions, get help, and discuss with other community members
|
||||||
|
- name: Request a Feature
|
||||||
|
url: https://github.com/orgs/fosrl/discussions/new?category=feature-requests
|
||||||
|
about: Feature requests should be opened as discussions so others can upvote and comment
|
||||||
2
.github/workflows/cicd.yml
vendored
2
.github/workflows/cicd.yml
vendored
@@ -8,7 +8,7 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
name: Build and Release
|
name: Build and Release
|
||||||
runs-on: ubuntu-latest
|
runs-on: amd64-runner
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
|
|||||||
132
.github/workflows/mirror.yaml
vendored
Normal file
132
.github/workflows/mirror.yaml
vendored
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
name: Mirror & Sign (Docker Hub to GHCR)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: {}
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
id-token: write # for keyless OIDC
|
||||||
|
|
||||||
|
env:
|
||||||
|
SOURCE_IMAGE: docker.io/fosrl/olm
|
||||||
|
DEST_IMAGE: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
mirror-and-dual-sign:
|
||||||
|
runs-on: amd64-runner
|
||||||
|
steps:
|
||||||
|
- name: Install skopeo + jq
|
||||||
|
run: |
|
||||||
|
sudo apt-get update -y
|
||||||
|
sudo apt-get install -y skopeo jq
|
||||||
|
skopeo --version
|
||||||
|
|
||||||
|
- name: Install cosign
|
||||||
|
uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0
|
||||||
|
|
||||||
|
- name: Input check
|
||||||
|
run: |
|
||||||
|
test -n "${SOURCE_IMAGE}" || (echo "SOURCE_IMAGE is empty" && exit 1)
|
||||||
|
echo "Source : ${SOURCE_IMAGE}"
|
||||||
|
echo "Target : ${DEST_IMAGE}"
|
||||||
|
|
||||||
|
# Auth for skopeo (containers-auth)
|
||||||
|
- name: Skopeo login to GHCR
|
||||||
|
run: |
|
||||||
|
skopeo login ghcr.io -u "${{ github.actor }}" -p "${{ secrets.GITHUB_TOKEN }}"
|
||||||
|
|
||||||
|
# Auth for cosign (docker-config)
|
||||||
|
- name: Docker login to GHCR (for cosign)
|
||||||
|
run: |
|
||||||
|
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin
|
||||||
|
|
||||||
|
- name: List source tags
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
skopeo list-tags --retry-times 3 docker://"${SOURCE_IMAGE}" \
|
||||||
|
| jq -r '.Tags[]' | sort -u > src-tags.txt
|
||||||
|
echo "Found source tags: $(wc -l < src-tags.txt)"
|
||||||
|
head -n 20 src-tags.txt || true
|
||||||
|
|
||||||
|
- name: List destination tags (skip existing)
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
if skopeo list-tags --retry-times 3 docker://"${DEST_IMAGE}" >/tmp/dst.json 2>/dev/null; then
|
||||||
|
jq -r '.Tags[]' /tmp/dst.json | sort -u > dst-tags.txt
|
||||||
|
else
|
||||||
|
: > dst-tags.txt
|
||||||
|
fi
|
||||||
|
echo "Existing destination tags: $(wc -l < dst-tags.txt)"
|
||||||
|
|
||||||
|
- name: Mirror, dual-sign, and verify
|
||||||
|
env:
|
||||||
|
# keyless
|
||||||
|
COSIGN_YES: "true"
|
||||||
|
# key-based
|
||||||
|
COSIGN_PRIVATE_KEY: ${{ secrets.COSIGN_PRIVATE_KEY }}
|
||||||
|
COSIGN_PASSWORD: ${{ secrets.COSIGN_PASSWORD }}
|
||||||
|
# verify
|
||||||
|
COSIGN_PUBLIC_KEY: ${{ secrets.COSIGN_PUBLIC_KEY }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
copied=0; skipped=0; v_ok=0; errs=0
|
||||||
|
|
||||||
|
issuer="https://token.actions.githubusercontent.com"
|
||||||
|
id_regex="^https://github.com/${{ github.repository }}/.+"
|
||||||
|
|
||||||
|
while read -r tag; do
|
||||||
|
[ -z "$tag" ] && continue
|
||||||
|
|
||||||
|
if grep -Fxq "$tag" dst-tags.txt; then
|
||||||
|
echo "::notice ::Skip (exists) ${DEST_IMAGE}:${tag}"
|
||||||
|
skipped=$((skipped+1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> Copy ${SOURCE_IMAGE}:${tag} → ${DEST_IMAGE}:${tag}"
|
||||||
|
if ! skopeo copy --all --retry-times 3 \
|
||||||
|
docker://"${SOURCE_IMAGE}:${tag}" docker://"${DEST_IMAGE}:${tag}"; then
|
||||||
|
echo "::warning title=Copy failed::${SOURCE_IMAGE}:${tag}"
|
||||||
|
errs=$((errs+1)); continue
|
||||||
|
fi
|
||||||
|
copied=$((copied+1))
|
||||||
|
|
||||||
|
digest="$(skopeo inspect --retry-times 3 docker://"${DEST_IMAGE}:${tag}" | jq -r '.Digest')"
|
||||||
|
ref="${DEST_IMAGE}@${digest}"
|
||||||
|
|
||||||
|
echo "==> cosign sign (keyless) --recursive ${ref}"
|
||||||
|
if ! cosign sign --recursive "${ref}"; then
|
||||||
|
echo "::warning title=Keyless sign failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign sign (key) --recursive ${ref}"
|
||||||
|
if ! cosign sign --key env://COSIGN_PRIVATE_KEY --recursive "${ref}"; then
|
||||||
|
echo "::warning title=Key sign failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign verify (public key) ${ref}"
|
||||||
|
if ! cosign verify --key env://COSIGN_PUBLIC_KEY "${ref}" -o text; then
|
||||||
|
echo "::warning title=Verify(pubkey) failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "==> cosign verify (keyless policy) ${ref}"
|
||||||
|
if ! cosign verify \
|
||||||
|
--certificate-oidc-issuer "${issuer}" \
|
||||||
|
--certificate-identity-regexp "${id_regex}" \
|
||||||
|
"${ref}" -o text; then
|
||||||
|
echo "::warning title=Verify(keyless) failed::${ref}"
|
||||||
|
errs=$((errs+1))
|
||||||
|
else
|
||||||
|
v_ok=$((v_ok+1))
|
||||||
|
fi
|
||||||
|
done < src-tags.txt
|
||||||
|
|
||||||
|
echo "---- Summary ----"
|
||||||
|
echo "Copied : $copied"
|
||||||
|
echo "Skipped : $skipped"
|
||||||
|
echo "Verified OK : $v_ok"
|
||||||
|
echo "Errors : $errs"
|
||||||
@@ -4,11 +4,7 @@ Contributions are welcome!
|
|||||||
|
|
||||||
Please see the contribution and local development guide on the docs page before getting started:
|
Please see the contribution and local development guide on the docs page before getting started:
|
||||||
|
|
||||||
https://docs.fossorial.io/development
|
https://docs.pangolin.net/development/contributing
|
||||||
|
|
||||||
For ideas about what features to work on and our future plans, please see the roadmap:
|
|
||||||
|
|
||||||
https://docs.fossorial.io/roadmap
|
|
||||||
|
|
||||||
### Licensing Considerations
|
### Licensing Considerations
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Olm is a [WireGuard](https://www.wireguard.com/) tunnel client designed to secur
|
|||||||
|
|
||||||
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
Olm is used with Pangolin and Newt as part of the larger system. See documentation below:
|
||||||
|
|
||||||
- [Full Documentation](https://docs.fossorial.io)
|
- [Full Documentation](https://docs.pangolin.net)
|
||||||
|
|
||||||
## Key Functions
|
## Key Functions
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ $ cat ~/.config/olm-client/config.json
|
|||||||
{
|
{
|
||||||
"id": "spmzu8rbpzj1qq6",
|
"id": "spmzu8rbpzj1qq6",
|
||||||
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
"secret": "f6v61mjutwme2kkydbw3fjo227zl60a2tsf5psw9r25hgae3",
|
||||||
"endpoint": "https://pangolin.fossorial.io",
|
"endpoint": "https://app.pangolin.net",
|
||||||
"tlsClientCert": ""
|
"tlsClientCert": ""
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
If you discover a security vulnerability, please follow the steps below to responsibly disclose it to us:
|
||||||
|
|
||||||
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
1. **Do not create a public GitHub issue or discussion post.** This could put the security of other users at risk.
|
||||||
2. Send a detailed report to [security@fossorial.io](mailto:security@fossorial.io) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
2. Send a detailed report to [security@pangolin.net](mailto:security@pangolin.net) or send a **private** message to a maintainer on [Discord](https://discord.gg/HCJR8Xhme4). Include:
|
||||||
|
|
||||||
- Description and location of the vulnerability.
|
- Description and location of the vulnerability.
|
||||||
- Potential impact of the vulnerability.
|
- Potential impact of the vulnerability.
|
||||||
|
|||||||
14
common.go
14
common.go
@@ -14,8 +14,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/websocket"
|
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
|
"github.com/fosrl/olm/websocket"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
@@ -402,11 +402,17 @@ func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID stri
|
|||||||
ticker := time.NewTicker(250 * time.Millisecond)
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-stopHolepunch:
|
case <-stopHolepunch:
|
||||||
logger.Info("Stopping UDP holepunch for all exit nodes")
|
logger.Info("Stopping UDP holepunch for all exit nodes")
|
||||||
return
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
||||||
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Send hole punch to all exit nodes
|
// Send hole punch to all exit nodes
|
||||||
for _, node := range resolvedNodes {
|
for _, node := range resolvedNodes {
|
||||||
@@ -471,11 +477,17 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, s
|
|||||||
ticker := time.NewTicker(250 * time.Millisecond)
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-stopHolepunch:
|
case <-stopHolepunch:
|
||||||
logger.Info("Stopping UDP holepunch")
|
logger.Info("Stopping UDP holepunch")
|
||||||
return
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||||
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
||||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||||
|
|||||||
484
config.go
Normal file
484
config.go
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OlmConfig holds all configuration options for the Olm client
|
||||||
|
type OlmConfig struct {
|
||||||
|
// Connection settings
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
MTU int `json:"mtu"`
|
||||||
|
DNS string `json:"dns"`
|
||||||
|
InterfaceName string `json:"interface"`
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
LogLevel string `json:"logLevel"`
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
EnableHTTP bool `json:"enableHttp"`
|
||||||
|
HTTPAddr string `json:"httpAddr"`
|
||||||
|
|
||||||
|
// Ping settings
|
||||||
|
PingInterval string `json:"pingInterval"`
|
||||||
|
PingTimeout string `json:"pingTimeout"`
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
Holepunch bool `json:"holepunch"`
|
||||||
|
TlsClientCert string `json:"tlsClientCert"`
|
||||||
|
|
||||||
|
// Parsed values (not in JSON)
|
||||||
|
PingIntervalDuration time.Duration `json:"-"`
|
||||||
|
PingTimeoutDuration time.Duration `json:"-"`
|
||||||
|
|
||||||
|
// Source tracking (not in JSON)
|
||||||
|
sources map[string]string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigSource tracks where each config value came from
|
||||||
|
type ConfigSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SourceDefault ConfigSource = "default"
|
||||||
|
SourceFile ConfigSource = "file"
|
||||||
|
SourceEnv ConfigSource = "environment"
|
||||||
|
SourceCLI ConfigSource = "cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultConfig returns a config with default values
|
||||||
|
func DefaultConfig() *OlmConfig {
|
||||||
|
config := &OlmConfig{
|
||||||
|
MTU: 1280,
|
||||||
|
DNS: "8.8.8.8",
|
||||||
|
LogLevel: "INFO",
|
||||||
|
InterfaceName: "olm",
|
||||||
|
EnableHTTP: false,
|
||||||
|
HTTPAddr: ":9452",
|
||||||
|
PingInterval: "3s",
|
||||||
|
PingTimeout: "5s",
|
||||||
|
Holepunch: false,
|
||||||
|
sources: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track default sources
|
||||||
|
config.sources["mtu"] = string(SourceDefault)
|
||||||
|
config.sources["dns"] = string(SourceDefault)
|
||||||
|
config.sources["logLevel"] = string(SourceDefault)
|
||||||
|
config.sources["interface"] = string(SourceDefault)
|
||||||
|
config.sources["enableHttp"] = string(SourceDefault)
|
||||||
|
config.sources["httpAddr"] = string(SourceDefault)
|
||||||
|
config.sources["pingInterval"] = string(SourceDefault)
|
||||||
|
config.sources["pingTimeout"] = string(SourceDefault)
|
||||||
|
config.sources["holepunch"] = string(SourceDefault)
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOlmConfigPath returns the path to the olm config file
|
||||||
|
func getOlmConfigPath() string {
|
||||||
|
configFile := os.Getenv("CONFIG_FILE")
|
||||||
|
if configFile != "" {
|
||||||
|
return configFile
|
||||||
|
}
|
||||||
|
|
||||||
|
var configDir string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
|
||||||
|
case "windows":
|
||||||
|
configDir = filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "olm-client")
|
||||||
|
default: // linux and others
|
||||||
|
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||||
|
fmt.Printf("Warning: Failed to create config directory: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(configDir, "config.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads configuration from file, env vars, and CLI args
|
||||||
|
// Priority: CLI args > Env vars > Config file > Defaults
|
||||||
|
// Returns: (config, showVersion, showConfig, error)
|
||||||
|
func LoadConfig(args []string) (*OlmConfig, bool, bool, error) {
|
||||||
|
// Start with defaults
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
// Load from config file (if exists)
|
||||||
|
fileConfig, err := loadConfigFromFile()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, fmt.Errorf("failed to load config file: %w", err)
|
||||||
|
}
|
||||||
|
if fileConfig != nil {
|
||||||
|
mergeConfigs(config, fileConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with environment variables
|
||||||
|
loadConfigFromEnv(config)
|
||||||
|
|
||||||
|
// Override with CLI arguments
|
||||||
|
showVersion, showConfig, err := loadConfigFromCLI(config, args)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse duration strings
|
||||||
|
if err := config.parseDurations(); err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, showVersion, showConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromFile loads configuration from the JSON config file
|
||||||
|
func loadConfigFromFile() (*OlmConfig, error) {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil // File doesn't exist, not an error
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var config OlmConfig
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromEnv loads configuration from environment variables
|
||||||
|
func loadConfigFromEnv(config *OlmConfig) {
|
||||||
|
if val := os.Getenv("PANGOLIN_ENDPOINT"); val != "" {
|
||||||
|
config.Endpoint = val
|
||||||
|
config.sources["endpoint"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("OLM_ID"); val != "" {
|
||||||
|
config.ID = val
|
||||||
|
config.sources["id"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("OLM_SECRET"); val != "" {
|
||||||
|
config.Secret = val
|
||||||
|
config.sources["secret"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("MTU"); val != "" {
|
||||||
|
if mtu, err := strconv.Atoi(val); err == nil {
|
||||||
|
config.MTU = mtu
|
||||||
|
config.sources["mtu"] = string(SourceEnv)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Invalid MTU value: %s, keeping current value\n", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv("DNS"); val != "" {
|
||||||
|
config.DNS = val
|
||||||
|
config.sources["dns"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("LOG_LEVEL"); val != "" {
|
||||||
|
config.LogLevel = val
|
||||||
|
config.sources["logLevel"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("INTERFACE"); val != "" {
|
||||||
|
config.InterfaceName = val
|
||||||
|
config.sources["interface"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("HTTP_ADDR"); val != "" {
|
||||||
|
config.HTTPAddr = val
|
||||||
|
config.sources["httpAddr"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("PING_INTERVAL"); val != "" {
|
||||||
|
config.PingInterval = val
|
||||||
|
config.sources["pingInterval"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("PING_TIMEOUT"); val != "" {
|
||||||
|
config.PingTimeout = val
|
||||||
|
config.sources["pingTimeout"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("ENABLE_HTTP"); val == "true" {
|
||||||
|
config.EnableHTTP = true
|
||||||
|
config.sources["enableHttp"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
if val := os.Getenv("HOLEPUNCH"); val == "true" {
|
||||||
|
config.Holepunch = true
|
||||||
|
config.sources["holepunch"] = string(SourceEnv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadConfigFromCLI loads configuration from command-line arguments
|
||||||
|
func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
|
||||||
|
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
||||||
|
|
||||||
|
// Store original values to detect changes
|
||||||
|
origValues := map[string]interface{}{
|
||||||
|
"endpoint": config.Endpoint,
|
||||||
|
"id": config.ID,
|
||||||
|
"secret": config.Secret,
|
||||||
|
"mtu": config.MTU,
|
||||||
|
"dns": config.DNS,
|
||||||
|
"logLevel": config.LogLevel,
|
||||||
|
"interface": config.InterfaceName,
|
||||||
|
"httpAddr": config.HTTPAddr,
|
||||||
|
"pingInterval": config.PingInterval,
|
||||||
|
"pingTimeout": config.PingTimeout,
|
||||||
|
"enableHttp": config.EnableHTTP,
|
||||||
|
"holepunch": config.Holepunch,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define flags
|
||||||
|
serviceFlags.StringVar(&config.Endpoint, "endpoint", config.Endpoint, "Endpoint of your Pangolin server")
|
||||||
|
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
|
||||||
|
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
|
||||||
|
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
|
||||||
|
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
|
||||||
|
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
|
serviceFlags.StringVar(&config.InterfaceName, "interface", config.InterfaceName, "Name of the WireGuard interface")
|
||||||
|
serviceFlags.StringVar(&config.HTTPAddr, "http-addr", config.HTTPAddr, "HTTP server address (e.g., ':9452')")
|
||||||
|
serviceFlags.StringVar(&config.PingInterval, "ping-interval", config.PingInterval, "Interval for pinging the server")
|
||||||
|
serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping")
|
||||||
|
serviceFlags.BoolVar(&config.EnableHTTP, "enable-http", config.EnableHTTP, "Enable HTTP server for receiving connection requests")
|
||||||
|
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
|
||||||
|
|
||||||
|
version := serviceFlags.Bool("version", false, "Print the version")
|
||||||
|
showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit")
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
if err := serviceFlags.Parse(args); err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track which values were changed by CLI args
|
||||||
|
if config.Endpoint != origValues["endpoint"].(string) {
|
||||||
|
config.sources["endpoint"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.ID != origValues["id"].(string) {
|
||||||
|
config.sources["id"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.Secret != origValues["secret"].(string) {
|
||||||
|
config.sources["secret"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.MTU != origValues["mtu"].(int) {
|
||||||
|
config.sources["mtu"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.DNS != origValues["dns"].(string) {
|
||||||
|
config.sources["dns"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.LogLevel != origValues["logLevel"].(string) {
|
||||||
|
config.sources["logLevel"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.InterfaceName != origValues["interface"].(string) {
|
||||||
|
config.sources["interface"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.HTTPAddr != origValues["httpAddr"].(string) {
|
||||||
|
config.sources["httpAddr"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.PingInterval != origValues["pingInterval"].(string) {
|
||||||
|
config.sources["pingInterval"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.PingTimeout != origValues["pingTimeout"].(string) {
|
||||||
|
config.sources["pingTimeout"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.EnableHTTP != origValues["enableHttp"].(bool) {
|
||||||
|
config.sources["enableHttp"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
if config.Holepunch != origValues["holepunch"].(bool) {
|
||||||
|
config.sources["holepunch"] = string(SourceCLI)
|
||||||
|
}
|
||||||
|
|
||||||
|
return *version, *showConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDurations parses the duration strings into time.Duration
|
||||||
|
func (c *OlmConfig) parseDurations() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Parse ping interval
|
||||||
|
if c.PingInterval != "" {
|
||||||
|
c.PingIntervalDuration, err = time.ParseDuration(c.PingInterval)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", c.PingInterval)
|
||||||
|
c.PingIntervalDuration = 3 * time.Second
|
||||||
|
c.PingInterval = "3s"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.PingIntervalDuration = 3 * time.Second
|
||||||
|
c.PingInterval = "3s"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse ping timeout
|
||||||
|
if c.PingTimeout != "" {
|
||||||
|
c.PingTimeoutDuration, err = time.ParseDuration(c.PingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", c.PingTimeout)
|
||||||
|
c.PingTimeoutDuration = 5 * time.Second
|
||||||
|
c.PingTimeout = "5s"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.PingTimeoutDuration = 5 * time.Second
|
||||||
|
c.PingTimeout = "5s"
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeConfigs merges source config into destination (only non-empty values)
|
||||||
|
// Also tracks that these values came from a file
|
||||||
|
func mergeConfigs(dest, src *OlmConfig) {
|
||||||
|
if src.Endpoint != "" {
|
||||||
|
dest.Endpoint = src.Endpoint
|
||||||
|
dest.sources["endpoint"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.ID != "" {
|
||||||
|
dest.ID = src.ID
|
||||||
|
dest.sources["id"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.Secret != "" {
|
||||||
|
dest.Secret = src.Secret
|
||||||
|
dest.sources["secret"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.MTU != 0 && src.MTU != 1280 {
|
||||||
|
dest.MTU = src.MTU
|
||||||
|
dest.sources["mtu"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.DNS != "" && src.DNS != "8.8.8.8" {
|
||||||
|
dest.DNS = src.DNS
|
||||||
|
dest.sources["dns"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.LogLevel != "" && src.LogLevel != "INFO" {
|
||||||
|
dest.LogLevel = src.LogLevel
|
||||||
|
dest.sources["logLevel"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.InterfaceName != "" && src.InterfaceName != "olm" {
|
||||||
|
dest.InterfaceName = src.InterfaceName
|
||||||
|
dest.sources["interface"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.HTTPAddr != "" && src.HTTPAddr != ":9452" {
|
||||||
|
dest.HTTPAddr = src.HTTPAddr
|
||||||
|
dest.sources["httpAddr"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.PingInterval != "" && src.PingInterval != "3s" {
|
||||||
|
dest.PingInterval = src.PingInterval
|
||||||
|
dest.sources["pingInterval"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.PingTimeout != "" && src.PingTimeout != "5s" {
|
||||||
|
dest.PingTimeout = src.PingTimeout
|
||||||
|
dest.sources["pingTimeout"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.TlsClientCert != "" {
|
||||||
|
dest.TlsClientCert = src.TlsClientCert
|
||||||
|
dest.sources["tlsClientCert"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
// For booleans, we always take the source value if explicitly set
|
||||||
|
if src.EnableHTTP {
|
||||||
|
dest.EnableHTTP = src.EnableHTTP
|
||||||
|
dest.sources["enableHttp"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
if src.Holepunch {
|
||||||
|
dest.Holepunch = src.Holepunch
|
||||||
|
dest.sources["holepunch"] = string(SourceFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveConfig saves the current configuration to the config file
|
||||||
|
func SaveConfig(config *OlmConfig) error {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
data, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal config: %w", err)
|
||||||
|
}
|
||||||
|
return os.WriteFile(configPath, data, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShowConfig prints the configuration and the source of each value
|
||||||
|
func (c *OlmConfig) ShowConfig() {
|
||||||
|
configPath := getOlmConfigPath()
|
||||||
|
|
||||||
|
fmt.Println("\n=== Olm Configuration ===\n")
|
||||||
|
fmt.Printf("Config File: %s\n", configPath)
|
||||||
|
|
||||||
|
// Check if config file exists
|
||||||
|
if _, err := os.Stat(configPath); err == nil {
|
||||||
|
fmt.Printf("Config File Status: ✓ exists\n")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Config File Status: ✗ not found\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n--- Configuration Values ---")
|
||||||
|
fmt.Println("(Format: Setting = Value [source])\n")
|
||||||
|
|
||||||
|
// Helper to get source or default
|
||||||
|
getSource := func(key string) string {
|
||||||
|
if source, ok := c.sources[key]; ok {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
return string(SourceDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to format value (mask secrets)
|
||||||
|
formatValue := func(key, value string) string {
|
||||||
|
if key == "secret" && value != "" {
|
||||||
|
if len(value) > 8 {
|
||||||
|
return value[:4] + "****" + value[len(value)-4:]
|
||||||
|
}
|
||||||
|
return "****"
|
||||||
|
}
|
||||||
|
if value == "" {
|
||||||
|
return "(not set)"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connection settings
|
||||||
|
fmt.Println("Connection:")
|
||||||
|
fmt.Printf(" endpoint = %s [%s]\n", formatValue("endpoint", c.Endpoint), getSource("endpoint"))
|
||||||
|
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
|
||||||
|
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
|
||||||
|
|
||||||
|
// Network settings
|
||||||
|
fmt.Println("\nNetwork:")
|
||||||
|
fmt.Printf(" mtu = %d [%s]\n", c.MTU, getSource("mtu"))
|
||||||
|
fmt.Printf(" dns = %s [%s]\n", c.DNS, getSource("dns"))
|
||||||
|
fmt.Printf(" interface = %s [%s]\n", c.InterfaceName, getSource("interface"))
|
||||||
|
|
||||||
|
// Logging
|
||||||
|
fmt.Println("\nLogging:")
|
||||||
|
fmt.Printf(" log-level = %s [%s]\n", c.LogLevel, getSource("logLevel"))
|
||||||
|
|
||||||
|
// HTTP server
|
||||||
|
fmt.Println("\nHTTP Server:")
|
||||||
|
fmt.Printf(" enable-http = %v [%s]\n", c.EnableHTTP, getSource("enableHttp"))
|
||||||
|
fmt.Printf(" http-addr = %s [%s]\n", c.HTTPAddr, getSource("httpAddr"))
|
||||||
|
|
||||||
|
// Timing
|
||||||
|
fmt.Println("\nTiming:")
|
||||||
|
fmt.Printf(" ping-interval = %s [%s]\n", c.PingInterval, getSource("pingInterval"))
|
||||||
|
fmt.Printf(" ping-timeout = %s [%s]\n", c.PingTimeout, getSource("pingTimeout"))
|
||||||
|
|
||||||
|
// Advanced
|
||||||
|
fmt.Println("\nAdvanced:")
|
||||||
|
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
|
||||||
|
if c.TlsClientCert != "" {
|
||||||
|
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source legend
|
||||||
|
fmt.Println("\n--- Source Legend ---")
|
||||||
|
fmt.Println(" default = Built-in default value")
|
||||||
|
fmt.Println(" file = Loaded from config file")
|
||||||
|
fmt.Println(" environment = Set via environment variable")
|
||||||
|
fmt.Println(" cli = Provided as command-line argument")
|
||||||
|
fmt.Println("\nPriority: cli > environment > file > default")
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
279
get-olm.sh
Normal file
279
get-olm.sh
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Get Olm - Cross-platform installation script
|
||||||
|
# Usage: curl -fsSL https://raw.githubusercontent.com/fosrl/olm/refs/heads/main/get-olm.sh | bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# GitHub repository info
|
||||||
|
REPO="fosrl/olm"
|
||||||
|
GITHUB_API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
||||||
|
|
||||||
|
# Function to print colored output
|
||||||
|
print_status() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
print_warning() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
print_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to get latest version from GitHub API
|
||||||
|
get_latest_version() {
|
||||||
|
local latest_info
|
||||||
|
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
latest_info=$(curl -fsSL "$GITHUB_API_URL" 2>/dev/null)
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
latest_info=$(wget -qO- "$GITHUB_API_URL" 2>/dev/null)
|
||||||
|
else
|
||||||
|
print_error "Neither curl nor wget is available. Please install one of them." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "$latest_info" ]; then
|
||||||
|
print_error "Failed to fetch latest version information" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract version from JSON response (works without jq)
|
||||||
|
local version=$(echo "$latest_info" | grep '"tag_name"' | head -1 | sed 's/.*"tag_name": *"\([^"]*\)".*/\1/')
|
||||||
|
|
||||||
|
if [ -z "$version" ]; then
|
||||||
|
print_error "Could not parse version from GitHub API response" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Remove 'v' prefix if present
|
||||||
|
version=$(echo "$version" | sed 's/^v//')
|
||||||
|
|
||||||
|
echo "$version"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Detect OS and architecture
|
||||||
|
detect_platform() {
|
||||||
|
local os arch
|
||||||
|
|
||||||
|
# Detect OS
|
||||||
|
case "$(uname -s)" in
|
||||||
|
Linux*) os="linux" ;;
|
||||||
|
Darwin*) os="darwin" ;;
|
||||||
|
MINGW*|MSYS*|CYGWIN*) os="windows" ;;
|
||||||
|
FreeBSD*) os="freebsd" ;;
|
||||||
|
*)
|
||||||
|
print_error "Unsupported operating system: $(uname -s)"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
# Detect architecture
|
||||||
|
case "$(uname -m)" in
|
||||||
|
x86_64|amd64) arch="amd64" ;;
|
||||||
|
arm64|aarch64) arch="arm64" ;;
|
||||||
|
armv7l|armv6l)
|
||||||
|
if [ "$os" = "linux" ]; then
|
||||||
|
if [ "$(uname -m)" = "armv6l" ]; then
|
||||||
|
arch="arm32v6"
|
||||||
|
else
|
||||||
|
arch="arm32"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
arch="arm64" # Default for non-Linux ARM
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
riscv64)
|
||||||
|
if [ "$os" = "linux" ]; then
|
||||||
|
arch="riscv64"
|
||||||
|
else
|
||||||
|
print_error "RISC-V architecture only supported on Linux"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
print_error "Unsupported architecture: $(uname -m)"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
echo "${os}_${arch}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get installation directory
|
||||||
|
get_install_dir() {
|
||||||
|
local platform="$1"
|
||||||
|
|
||||||
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
|
echo "$HOME/bin"
|
||||||
|
else
|
||||||
|
# For Unix-like systems, prioritize system-wide directories for sudo access
|
||||||
|
# Check in order of preference: /usr/local/bin, /usr/bin, ~/.local/bin
|
||||||
|
if [ -d "/usr/local/bin" ]; then
|
||||||
|
echo "/usr/local/bin"
|
||||||
|
elif [ -d "/usr/bin" ]; then
|
||||||
|
echo "/usr/bin"
|
||||||
|
else
|
||||||
|
# Fallback to user directory if system directories don't exist
|
||||||
|
echo "$HOME/.local/bin"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if we need sudo for installation
|
||||||
|
need_sudo() {
|
||||||
|
local install_dir="$1"
|
||||||
|
|
||||||
|
# If installing to system directory and we don't have write permission, need sudo
|
||||||
|
if [[ "$install_dir" == "/usr/local/bin" || "$install_dir" == "/usr/bin" ]]; then
|
||||||
|
if [ ! -w "$install_dir" ] 2>/dev/null; then
|
||||||
|
return 0 # Need sudo
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
return 1 # Don't need sudo
|
||||||
|
}
|
||||||
|
|
||||||
|
# Download and install olm
|
||||||
|
install_olm() {
|
||||||
|
local platform="$1"
|
||||||
|
local install_dir="$2"
|
||||||
|
local binary_name="olm_${platform}"
|
||||||
|
local exe_suffix=""
|
||||||
|
|
||||||
|
# Add .exe suffix for Windows
|
||||||
|
if [[ "$platform" == *"windows"* ]]; then
|
||||||
|
binary_name="${binary_name}.exe"
|
||||||
|
exe_suffix=".exe"
|
||||||
|
fi
|
||||||
|
|
||||||
|
local download_url="${BASE_URL}/${binary_name}"
|
||||||
|
local temp_file="/tmp/olm${exe_suffix}"
|
||||||
|
local final_path="${install_dir}/olm${exe_suffix}"
|
||||||
|
|
||||||
|
print_status "Downloading olm from ${download_url}"
|
||||||
|
|
||||||
|
# Download the binary
|
||||||
|
if command -v curl >/dev/null 2>&1; then
|
||||||
|
curl -fsSL "$download_url" -o "$temp_file"
|
||||||
|
elif command -v wget >/dev/null 2>&1; then
|
||||||
|
wget -q "$download_url" -O "$temp_file"
|
||||||
|
else
|
||||||
|
print_error "Neither curl nor wget is available. Please install one of them."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if we need sudo for installation
|
||||||
|
local use_sudo=""
|
||||||
|
if need_sudo "$install_dir"; then
|
||||||
|
print_status "Administrator privileges required for system-wide installation"
|
||||||
|
if command -v sudo >/dev/null 2>&1; then
|
||||||
|
use_sudo="sudo"
|
||||||
|
else
|
||||||
|
print_error "sudo is required for system-wide installation but not available"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create install directory if it doesn't exist
|
||||||
|
if [ -n "$use_sudo" ]; then
|
||||||
|
$use_sudo mkdir -p "$install_dir"
|
||||||
|
else
|
||||||
|
mkdir -p "$install_dir"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Move binary to install directory
|
||||||
|
if [ -n "$use_sudo" ]; then
|
||||||
|
$use_sudo mv "$temp_file" "$final_path"
|
||||||
|
$use_sudo chmod +x "$final_path"
|
||||||
|
else
|
||||||
|
mv "$temp_file" "$final_path"
|
||||||
|
chmod +x "$final_path"
|
||||||
|
fi
|
||||||
|
|
||||||
|
print_status "olm installed to ${final_path}"
|
||||||
|
|
||||||
|
# Check if install directory is in PATH (only warn for non-system directories)
|
||||||
|
if [[ "$install_dir" != "/usr/local/bin" && "$install_dir" != "/usr/bin" ]]; then
|
||||||
|
if ! echo "$PATH" | grep -q "$install_dir"; then
|
||||||
|
print_warning "Install directory ${install_dir} is not in your PATH."
|
||||||
|
print_warning "Add it to your PATH by adding this line to your shell profile:"
|
||||||
|
print_warning " export PATH=\"${install_dir}:\$PATH\""
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
verify_installation() {
|
||||||
|
local install_dir="$1"
|
||||||
|
local exe_suffix=""
|
||||||
|
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
|
exe_suffix=".exe"
|
||||||
|
fi
|
||||||
|
|
||||||
|
local olm_path="${install_dir}/olm${exe_suffix}"
|
||||||
|
|
||||||
|
if [ -f "$olm_path" ] && [ -x "$olm_path" ]; then
|
||||||
|
print_status "Installation successful!"
|
||||||
|
print_status "olm version: $("$olm_path" --version 2>/dev/null || echo "unknown")"
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
print_error "Installation failed. Binary not found or not executable."
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main installation process
|
||||||
|
main() {
|
||||||
|
print_status "Installing latest version of olm..."
|
||||||
|
|
||||||
|
# Get latest version
|
||||||
|
print_status "Fetching latest version from GitHub..."
|
||||||
|
VERSION=$(get_latest_version)
|
||||||
|
print_status "Latest version: v${VERSION}"
|
||||||
|
|
||||||
|
# Set base URL with the fetched version
|
||||||
|
BASE_URL="https://github.com/${REPO}/releases/download/${VERSION}"
|
||||||
|
|
||||||
|
# Detect platform
|
||||||
|
PLATFORM=$(detect_platform)
|
||||||
|
print_status "Detected platform: ${PLATFORM}"
|
||||||
|
|
||||||
|
# Get install directory
|
||||||
|
INSTALL_DIR=$(get_install_dir "$PLATFORM")
|
||||||
|
print_status "Install directory: ${INSTALL_DIR}"
|
||||||
|
|
||||||
|
# Inform user about system-wide installation
|
||||||
|
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||||
|
print_status "Installing system-wide for sudo access"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install olm
|
||||||
|
install_olm "$PLATFORM" "$INSTALL_DIR"
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
if verify_installation "$INSTALL_DIR"; then
|
||||||
|
print_status "olm is ready to use!"
|
||||||
|
if [[ "$INSTALL_DIR" == "/usr/local/bin" || "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||||
|
print_status "olm is installed system-wide and accessible via sudo"
|
||||||
|
fi
|
||||||
|
if [[ "$PLATFORM" == *"windows"* ]]; then
|
||||||
|
print_status "Run 'olm --help' to get started"
|
||||||
|
else
|
||||||
|
print_status "Run 'olm --help' or 'sudo olm --help' to get started"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run main function
|
||||||
|
main "$@"
|
||||||
8
go.mod
8
go.mod
@@ -3,11 +3,11 @@ module github.com/fosrl/olm
|
|||||||
go 1.25
|
go 1.25
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7
|
||||||
github.com/vishvananda/netlink v1.3.1
|
github.com/vishvananda/netlink v1.3.1
|
||||||
golang.org/x/crypto v0.42.0
|
golang.org/x/crypto v0.43.0
|
||||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||||
golang.org/x/sys v0.36.0
|
golang.org/x/sys v0.37.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||||
)
|
)
|
||||||
@@ -15,7 +15,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
golang.org/x/net v0.43.0 // indirect
|
golang.org/x/net v0.45.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
|
||||||
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
|
software.sslmate.com/src/go-pkcs12 v0.6.0 // indirect
|
||||||
|
|||||||
16
go.sum
16
go.sum
@@ -1,5 +1,5 @@
|
|||||||
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a h1:bUGN4piHlcqgfdRLrwqiLZZxgcitzBzNDQS1+CHSmJI=
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7 h1:6bSU8Efyhx1SR53iSw1Wjk5V8vDfizGAudq/GlE9b+o=
|
||||||
github.com/fosrl/newt v0.0.0-20250730062419-3ccd755d557a/go.mod h1:PbiPYp1hbL07awrmbqTSTz7lTenieTHN6cIkUVCGD3I=
|
github.com/fosrl/newt v0.0.0-20250929233849-71c5bf7e65f7/go.mod h1:Ac0k2FmAMC+hu21rAK+p7EnnEGrqKO/QZuGTVHA/XDM=
|
||||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
@@ -10,16 +10,16 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW
|
|||||||
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
|
||||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4=
|
||||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
|
|||||||
302
main.go
302
main.go
@@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,10 +14,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/olm/httpserver"
|
"github.com/fosrl/olm/httpserver"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/wgtester"
|
"github.com/fosrl/olm/websocket"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -53,7 +52,6 @@ func formatEndpoint(endpoint string) string {
|
|||||||
return endpoint
|
return endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Check if we're running as a Windows service
|
// Check if we're running as a Windows service
|
||||||
if isWindowsService() {
|
if isWindowsService() {
|
||||||
@@ -145,16 +143,27 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
case "config":
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
showServiceConfig()
|
||||||
|
} else {
|
||||||
|
fmt.Println("Service configuration is only available on Windows")
|
||||||
|
}
|
||||||
|
return
|
||||||
case "help", "--help", "-h":
|
case "help", "--help", "-h":
|
||||||
fmt.Println("Olm WireGuard VPN Client")
|
fmt.Println("Olm WireGuard VPN Client")
|
||||||
fmt.Println("\nWindows Service Management:")
|
fmt.Println("\nWindows Service Management:")
|
||||||
fmt.Println(" install Install the service")
|
fmt.Println(" install Install the service")
|
||||||
fmt.Println(" remove Remove the service")
|
fmt.Println(" remove Remove the service")
|
||||||
fmt.Println(" start Start the service")
|
fmt.Println(" start [args] Start the service with optional arguments")
|
||||||
fmt.Println(" stop Stop the service")
|
fmt.Println(" stop Stop the service")
|
||||||
fmt.Println(" status Show service status")
|
fmt.Println(" status Show service status")
|
||||||
fmt.Println(" debug Run service in debug mode")
|
fmt.Println(" debug [args] Run service in debug mode with optional arguments")
|
||||||
fmt.Println(" logs Tail the service log file")
|
fmt.Println(" logs Tail the service log file")
|
||||||
|
fmt.Println(" config Show current service configuration")
|
||||||
|
fmt.Println("\nExamples:")
|
||||||
|
fmt.Println(" olm start --enable-http --http-addr :9452")
|
||||||
|
fmt.Println(" olm debug --endpoint https://example.com --id myid --secret mysecret")
|
||||||
fmt.Println("\nFor console mode, run without arguments or with standard flags.")
|
fmt.Println("\nFor console mode, run without arguments or with standard flags.")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
@@ -193,122 +202,40 @@ func runOlmMain(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
func runOlmMainWithArgs(ctx context.Context, args []string) {
|
||||||
// Log that we've entered the main function
|
// Load configuration from file, env vars, and CLI args
|
||||||
// fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args)
|
// Priority: CLI args > Env vars > Config file > Defaults
|
||||||
|
config, showVersion, showConfig, err := LoadConfig(args)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to load configuration: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new FlagSet for parsing service arguments
|
// Handle --show-config flag
|
||||||
serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError)
|
if showConfig {
|
||||||
|
config.ShowConfig()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract commonly used values from config for convenience
|
||||||
var (
|
var (
|
||||||
endpoint string
|
endpoint = config.Endpoint
|
||||||
id string
|
id = config.ID
|
||||||
secret string
|
secret = config.Secret
|
||||||
mtu string
|
mtu = config.MTU
|
||||||
mtuInt int
|
logLevel = config.LogLevel
|
||||||
dns string
|
interfaceName = config.InterfaceName
|
||||||
|
enableHTTP = config.EnableHTTP
|
||||||
|
httpAddr = config.HTTPAddr
|
||||||
|
pingInterval = config.PingIntervalDuration
|
||||||
|
pingTimeout = config.PingTimeoutDuration
|
||||||
|
doHolepunch = config.Holepunch
|
||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
err error
|
|
||||||
logLevel string
|
|
||||||
interfaceName string
|
|
||||||
enableHTTP bool
|
|
||||||
httpAddr string
|
|
||||||
testMode bool // Add this var for the test flag
|
|
||||||
testTarget string // Add this var for test target
|
|
||||||
pingInterval time.Duration
|
|
||||||
pingTimeout time.Duration
|
|
||||||
doHolepunch bool
|
|
||||||
connected bool
|
connected bool
|
||||||
)
|
)
|
||||||
|
|
||||||
stopHolepunch = make(chan struct{})
|
stopHolepunch = make(chan struct{})
|
||||||
stopPing = make(chan struct{})
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
|
||||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
|
||||||
id = os.Getenv("OLM_ID")
|
|
||||||
secret = os.Getenv("OLM_SECRET")
|
|
||||||
mtu = os.Getenv("MTU")
|
|
||||||
dns = os.Getenv("DNS")
|
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
|
||||||
interfaceName = os.Getenv("INTERFACE")
|
|
||||||
httpAddr = os.Getenv("HTTP_ADDR")
|
|
||||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
|
||||||
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
|
||||||
enableHTTPEnv := os.Getenv("ENABLE_HTTP")
|
|
||||||
holepunchEnv := os.Getenv("HOLEPUNCH")
|
|
||||||
|
|
||||||
enableHTTP = enableHTTPEnv == "true"
|
|
||||||
doHolepunch = holepunchEnv == "true"
|
|
||||||
|
|
||||||
if endpoint == "" {
|
|
||||||
serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server")
|
|
||||||
}
|
|
||||||
if id == "" {
|
|
||||||
serviceFlags.StringVar(&id, "id", "", "Olm ID")
|
|
||||||
}
|
|
||||||
if secret == "" {
|
|
||||||
serviceFlags.StringVar(&secret, "secret", "", "Olm secret")
|
|
||||||
}
|
|
||||||
if mtu == "" {
|
|
||||||
serviceFlags.StringVar(&mtu, "mtu", "1280", "MTU to use")
|
|
||||||
}
|
|
||||||
if dns == "" {
|
|
||||||
serviceFlags.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
|
|
||||||
}
|
|
||||||
if logLevel == "" {
|
|
||||||
serviceFlags.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
|
||||||
}
|
|
||||||
if interfaceName == "" {
|
|
||||||
serviceFlags.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface")
|
|
||||||
}
|
|
||||||
if httpAddr == "" {
|
|
||||||
serviceFlags.StringVar(&httpAddr, "http-addr", ":9452", "HTTP server address (e.g., ':9452')")
|
|
||||||
}
|
|
||||||
if pingIntervalStr == "" {
|
|
||||||
serviceFlags.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)")
|
|
||||||
}
|
|
||||||
if pingTimeoutStr == "" {
|
|
||||||
serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)")
|
|
||||||
}
|
|
||||||
if enableHTTPEnv == "" {
|
|
||||||
serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests")
|
|
||||||
}
|
|
||||||
if holepunchEnv == "" {
|
|
||||||
serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)")
|
|
||||||
}
|
|
||||||
|
|
||||||
version := serviceFlags.Bool("version", false, "Print the version")
|
|
||||||
|
|
||||||
// Parse the service arguments
|
|
||||||
if err := serviceFlags.Parse(args); err != nil {
|
|
||||||
fmt.Printf("Error parsing service arguments: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Print final values after flag parsing
|
|
||||||
// fmt.Printf("After flag parsing: endpoint='%s', id='%s', secret='%s'\n", endpoint, id, secret)
|
|
||||||
|
|
||||||
// Parse ping intervals
|
|
||||||
if pingIntervalStr != "" {
|
|
||||||
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 3 seconds\n", pingIntervalStr)
|
|
||||||
pingInterval = 3 * time.Second
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
pingInterval = 3 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
if pingTimeoutStr != "" {
|
|
||||||
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 5 seconds\n", pingTimeoutStr)
|
|
||||||
pingTimeout = 5 * time.Second
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
pingTimeout = 5 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup Windows event logging if on Windows
|
// Setup Windows event logging if on Windows
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
setupWindowsEventLog()
|
setupWindowsEventLog()
|
||||||
@@ -320,11 +247,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||||
|
|
||||||
olmVersion := "version_replaceme"
|
olmVersion := "version_replaceme"
|
||||||
if *version {
|
if showVersion {
|
||||||
fmt.Println("Olm version " + olmVersion)
|
fmt.Println("Olm version " + olmVersion)
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
} else {
|
}
|
||||||
logger.Info("Olm version " + olmVersion)
|
logger.Info("Olm version " + olmVersion)
|
||||||
|
|
||||||
|
if err := updates.CheckForUpdate("fosrl", "olm", olmVersion); err != nil {
|
||||||
|
logger.Debug("Failed to check for updates: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log startup information
|
// Log startup information
|
||||||
@@ -336,35 +266,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle test mode
|
|
||||||
if testMode {
|
|
||||||
if testTarget == "" {
|
|
||||||
logger.Fatal("Test mode requires -test-target to be set to a server:port")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Running in test mode, connecting to %s", testTarget)
|
|
||||||
|
|
||||||
// Create a new tester client
|
|
||||||
tester, err := wgtester.NewClient(testTarget)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create tester client: %v", err)
|
|
||||||
}
|
|
||||||
defer tester.Close()
|
|
||||||
|
|
||||||
// Test connection with a 2-second timeout
|
|
||||||
connected, rtt := tester.TestConnectionWithTimeout(2 * time.Second)
|
|
||||||
|
|
||||||
if connected {
|
|
||||||
logger.Info("Connection test successful! RTT: %v", rtt)
|
|
||||||
fmt.Printf("Connection test successful! RTT: %v\n", rtt)
|
|
||||||
os.Exit(0)
|
|
||||||
} else {
|
|
||||||
logger.Error("Connection test failed - no response received")
|
|
||||||
fmt.Println("Connection test failed - no response received")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var httpServer *httpserver.HTTPServer
|
var httpServer *httpserver.HTTPServer
|
||||||
if enableHTTP {
|
if enableHTTP {
|
||||||
httpServer = httpserver.NewHTTPServer(httpAddr)
|
httpServer = httpserver.NewHTTPServer(httpAddr)
|
||||||
@@ -422,9 +323,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create olm: %v", err)
|
logger.Fatal("Failed to create olm: %v", err)
|
||||||
}
|
}
|
||||||
endpoint = olm.GetConfig().Endpoint // Update endpoint from config
|
|
||||||
id = olm.GetConfig().ID // Update ID from config
|
|
||||||
secret = olm.GetConfig().Secret // Update secret from config
|
|
||||||
|
|
||||||
// wait until we have a client id and secret and endpoint
|
// wait until we have a client id and secret and endpoint
|
||||||
waitCount := 0
|
waitCount := 0
|
||||||
@@ -452,12 +350,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse the mtu string into an int
|
|
||||||
mtuInt, err = strconv.Atoi(mtu)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to parse MTU: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
logger.Fatal("Failed to generate private key: %v", err)
|
||||||
@@ -578,12 +470,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, mtuInt)
|
return tun.CreateTUN(interfaceName, mtu)
|
||||||
}
|
}
|
||||||
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
|
||||||
return createTUNFromFD(tunFdStr, mtuInt)
|
return createTUNFromFD(tunFdStr, mtu)
|
||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, mtuInt)
|
return tun.CreateTUN(interfaceName, mtu)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -598,30 +490,47 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
fileUAPI, err := func() (*os.File, error) {
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
|
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
|
||||||
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
||||||
if err != nil { return nil, err }
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return os.NewFile(uintptr(fd), ""), nil
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
}
|
}
|
||||||
return uapiOpen(interfaceName)
|
return uapiOpen(interfaceName)
|
||||||
}()
|
}()
|
||||||
if err != nil { logger.Error("UAPI listen error: %v", err); os.Exit(1); return }
|
if err != nil {
|
||||||
|
logger.Error("UAPI listen error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
|
|
||||||
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
if err != nil { logger.Error("Failed to listen on uapi socket: %v", err); os.Exit(1) }
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := uapiListener.Accept()
|
conn, err := uapiListener.Accept()
|
||||||
if err != nil { return }
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
go dev.IpcHandle(conn)
|
go dev.IpcHandle(conn)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
logger.Info("UAPI listener started")
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) }
|
if err = dev.Up(); err != nil {
|
||||||
if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) }
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) }
|
}
|
||||||
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
|
}
|
||||||
|
if httpServer != nil {
|
||||||
|
httpServer.SetTunnelIP(wgData.TunnelIP)
|
||||||
|
}
|
||||||
|
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
@@ -661,9 +570,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
// Format the endpoint before configuring the peer.
|
// Format the endpoint before configuring the peer.
|
||||||
site.Endpoint = formatEndpoint(site.Endpoint)
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||||
|
|
||||||
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err); return }
|
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||||
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err); return }
|
logger.Error("Failed to configure peer: %v", err)
|
||||||
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }
|
return
|
||||||
|
}
|
||||||
|
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Configured peer %s", site.PublicKey)
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
}
|
}
|
||||||
@@ -702,19 +620,33 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
// Update the peer in WireGuard
|
// Update the peer in WireGuard
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
// Find the existing peer to get old RemoteSubnets
|
// Find the existing peer to get old data
|
||||||
var oldRemoteSubnets string
|
var oldRemoteSubnets string
|
||||||
|
var oldPublicKey string
|
||||||
for _, site := range wgData.Sites {
|
for _, site := range wgData.Sites {
|
||||||
if site.SiteId == updateData.SiteId {
|
if site.SiteId == updateData.SiteId {
|
||||||
oldRemoteSubnets = site.RemoteSubnets
|
oldRemoteSubnets = site.RemoteSubnets
|
||||||
|
oldPublicKey = site.PublicKey
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the public key has changed, remove the old peer first
|
||||||
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||||
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||||
|
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||||
|
logger.Error("Failed to remove old peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Format the endpoint before updating the peer.
|
// Format the endpoint before updating the peer.
|
||||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err); return }
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
|
logger.Error("Failed to update peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Remove old remote subnet routes if they changed
|
// Remove old remote subnet routes if they changed
|
||||||
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
||||||
@@ -733,7 +665,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
// Update successful
|
// Update successful
|
||||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
for i := range wgData.Sites {
|
for i := range wgData.Sites {
|
||||||
if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig; break }
|
if wgData.Sites[i].SiteId == updateData.SiteId {
|
||||||
|
wgData.Sites[i] = siteConfig
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Error("WireGuard device not initialized")
|
logger.Error("WireGuard device not initialized")
|
||||||
@@ -771,9 +706,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
// Format the endpoint before adding the new peer.
|
// Format the endpoint before adding the new peer.
|
||||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||||
|
|
||||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err); return }
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err); return }
|
logger.Error("Failed to add peer: %v", err)
|
||||||
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return }
|
return
|
||||||
|
}
|
||||||
|
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add route for new peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Add successful
|
// Add successful
|
||||||
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
||||||
@@ -907,6 +851,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
httpServer.SetConnectionStatus(true)
|
httpServer.SetConnectionStatus(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CRITICAL: Save our full config AFTER websocket saves its limited config
|
||||||
|
// This ensures all 13 fields are preserved, not just the 4 that websocket saves
|
||||||
|
if err := SaveConfig(config); err != nil {
|
||||||
|
logger.Error("Failed to save full olm config: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Saved full olm config with all options")
|
||||||
|
}
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
logger.Debug("Already connected, skipping registration")
|
logger.Debug("Already connected, skipping registration")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
88
olm.iss
Normal file
88
olm.iss
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
; Script generated by the Inno Setup Script Wizard.
|
||||||
|
; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
|
||||||
|
|
||||||
|
#define MyAppName "olm"
|
||||||
|
#define MyAppVersion "1.0.0"
|
||||||
|
#define MyAppPublisher "Fossorial Inc."
|
||||||
|
#define MyAppURL "https://pangolin.net"
|
||||||
|
#define MyAppExeName "olm.exe"
|
||||||
|
|
||||||
|
[Setup]
|
||||||
|
; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
|
||||||
|
; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
|
||||||
|
AppId={{44A24E4C-B616-476F-ADE7-8D56B930959E}
|
||||||
|
AppName={#MyAppName}
|
||||||
|
AppVersion={#MyAppVersion}
|
||||||
|
;AppVerName={#MyAppName} {#MyAppVersion}
|
||||||
|
AppPublisher={#MyAppPublisher}
|
||||||
|
AppPublisherURL={#MyAppURL}
|
||||||
|
AppSupportURL={#MyAppURL}
|
||||||
|
AppUpdatesURL={#MyAppURL}
|
||||||
|
DefaultDirName={autopf}\{#MyAppName}
|
||||||
|
UninstallDisplayIcon={app}\{#MyAppExeName}
|
||||||
|
; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run
|
||||||
|
; on anything but x64 and Windows 11 on Arm.
|
||||||
|
ArchitecturesAllowed=x64compatible
|
||||||
|
; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the
|
||||||
|
; install be done in "64-bit mode" on x64 or Windows 11 on Arm,
|
||||||
|
; meaning it should use the native 64-bit Program Files directory and
|
||||||
|
; the 64-bit view of the registry.
|
||||||
|
ArchitecturesInstallIn64BitMode=x64compatible
|
||||||
|
DefaultGroupName={#MyAppName}
|
||||||
|
DisableProgramGroupPage=yes
|
||||||
|
; Uncomment the following line to run in non administrative install mode (install for current user only).
|
||||||
|
;PrivilegesRequired=lowest
|
||||||
|
OutputBaseFilename=mysetup
|
||||||
|
SolidCompression=yes
|
||||||
|
WizardStyle=modern
|
||||||
|
; Add this to ensure PATH changes are applied and the system is prompted for a restart if needed
|
||||||
|
RestartIfNeededByRun=no
|
||||||
|
ChangesEnvironment=true
|
||||||
|
|
||||||
|
[Languages]
|
||||||
|
Name: "english"; MessagesFile: "compiler:Default.isl"
|
||||||
|
|
||||||
|
[Files]
|
||||||
|
; The 'DestName' flag ensures that 'olm_windows_amd64.exe' is installed as 'olm.exe'
|
||||||
|
Source: "C:\Users\Administrator\Downloads\olm_windows_amd64.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}"; Flags: ignoreversion
|
||||||
|
Source: "C:\Users\Administrator\Downloads\wintun.dll"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
|
; NOTE: Don't use "Flags: ignoreversion" on any shared system files
|
||||||
|
|
||||||
|
[Icons]
|
||||||
|
Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
|
||||||
|
|
||||||
|
[Registry]
|
||||||
|
; Add the application's installation directory to the system PATH environment variable.
|
||||||
|
; HKLM (HKEY_LOCAL_MACHINE) is used for system-wide changes.
|
||||||
|
; The 'Path' variable is located under 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'.
|
||||||
|
; ValueType: expandsz allows for environment variables (like %ProgramFiles%) in the path.
|
||||||
|
; ValueData: "{olddata};{app}" appends the current application directory to the existing PATH.
|
||||||
|
; Flags: uninsdeletevalue ensures the entry is removed upon uninstallation.
|
||||||
|
; Check: IsWin64 ensures this is applied on 64-bit systems, which matches ArchitecturesAllowed.
|
||||||
|
[Registry]
|
||||||
|
; Add the application's installation directory to the system PATH.
|
||||||
|
Root: HKLM; Subkey: "SYSTEM\CurrentControlSet\Control\Session Manager\Environment"; \
|
||||||
|
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
|
||||||
|
Flags: uninsdeletevalue; Check: NeedsAddPath(ExpandConstant('{app}'))
|
||||||
|
|
||||||
|
[Code]
|
||||||
|
function NeedsAddPath(Path: string): boolean;
|
||||||
|
var
|
||||||
|
OrigPath: string;
|
||||||
|
begin
|
||||||
|
if not RegQueryStringValue(HKEY_LOCAL_MACHINE,
|
||||||
|
'SYSTEM\CurrentControlSet\Control\Session Manager\Environment',
|
||||||
|
'Path', OrigPath)
|
||||||
|
then begin
|
||||||
|
// Path variable doesn't exist at all, so we definitely need to add it.
|
||||||
|
Result := True;
|
||||||
|
exit;
|
||||||
|
end;
|
||||||
|
|
||||||
|
// Perform a case-insensitive check to see if the path is already present.
|
||||||
|
// We add semicolons to prevent partial matches (e.g., matching C:\App in C:\App2).
|
||||||
|
if Pos(';' + UpperCase(Path) + ';', ';' + UpperCase(OrigPath) + ';') > 0 then
|
||||||
|
Result := False
|
||||||
|
else
|
||||||
|
Result := True;
|
||||||
|
end;
|
||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"github.com/fosrl/olm/wgtester"
|
"github.com/fosrl/olm/wgtester"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
)
|
)
|
||||||
@@ -205,11 +205,11 @@ func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for IPv6 and format the endpoint correctly
|
// Check for IPv6 and format the endpoint correctly
|
||||||
formattedEndpoint := relayEndpoint
|
formattedEndpoint := relayEndpoint
|
||||||
if strings.Contains(relayEndpoint, ":") {
|
if strings.Contains(relayEndpoint, ":") {
|
||||||
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure WireGuard to use the relay
|
// Configure WireGuard to use the relay
|
||||||
wgConfig := fmt.Sprintf(`private_key=%s
|
wgConfig := fmt.Sprintf(`private_key=%s
|
||||||
|
|||||||
@@ -48,3 +48,7 @@ func setupWindowsEventLog() {
|
|||||||
func watchLogFile(end bool) error {
|
func watchLogFile(end bool) error {
|
||||||
return fmt.Errorf("watching log file is only available on Windows")
|
return fmt.Errorf("watching log file is only available on Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func showServiceConfig() {
|
||||||
|
fmt.Println("Service configuration is only available on Windows")
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -69,12 +70,6 @@ func loadServiceArgs() ([]string, error) {
|
|||||||
return nil, fmt.Errorf("failed to read service args: %v", err)
|
return nil, fmt.Errorf("failed to read service args: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete the file after reading
|
|
||||||
err = os.Remove(argsPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to delete service args file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var args []string
|
var args []string
|
||||||
err = json.Unmarshal(data, &args)
|
err = json.Unmarshal(data, &args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,7 +90,7 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
|||||||
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||||
changes <- svc.Status{State: svc.StartPending}
|
changes <- svc.Status{State: svc.StartPending}
|
||||||
|
|
||||||
s.elog.Info(1, "Service Execute called, starting main logic")
|
s.elog.Info(1, fmt.Sprintf("Service Execute called with args: %v", args))
|
||||||
|
|
||||||
// Load saved service arguments
|
// Load saved service arguments
|
||||||
savedArgs, err := loadServiceArgs()
|
savedArgs, err := loadServiceArgs()
|
||||||
@@ -104,7 +99,24 @@ func (s *olmService) Execute(args []string, r <-chan svc.ChangeRequest, changes
|
|||||||
// Continue with empty args if loading fails
|
// Continue with empty args if loading fails
|
||||||
savedArgs = []string{}
|
savedArgs = []string{}
|
||||||
}
|
}
|
||||||
s.args = savedArgs
|
|
||||||
|
// Combine service start args with saved args, giving priority to service start args
|
||||||
|
finalArgs := []string{}
|
||||||
|
if len(args) > 0 {
|
||||||
|
// Skip the first arg which is typically the service name
|
||||||
|
if len(args) > 1 {
|
||||||
|
finalArgs = append(finalArgs, args[1:]...)
|
||||||
|
}
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Using service start parameters: %v", finalArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no service start parameters, use saved args
|
||||||
|
if len(finalArgs) == 0 && len(savedArgs) > 0 {
|
||||||
|
finalArgs = savedArgs
|
||||||
|
s.elog.Info(1, fmt.Sprintf("Using saved service args: %v", finalArgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.args = finalArgs
|
||||||
|
|
||||||
// Start the main olm functionality
|
// Start the main olm functionality
|
||||||
olmDone := make(chan struct{})
|
olmDone := make(chan struct{})
|
||||||
@@ -309,7 +321,7 @@ func removeService() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func startService(args []string) error {
|
func startService(args []string) error {
|
||||||
// Save the service arguments before starting
|
// Save the service arguments as backup
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
err := saveServiceArgs(args)
|
err := saveServiceArgs(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -329,7 +341,8 @@ func startService(args []string) error {
|
|||||||
}
|
}
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
err = s.Start()
|
// Pass arguments directly to the service start call
|
||||||
|
err = s.Start(args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start service: %v", err)
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
}
|
}
|
||||||
@@ -379,17 +392,12 @@ func debugService(args []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fmt.Printf("Starting service in debug mode...\n")
|
// Start the service with the provided arguments
|
||||||
|
err := startService(args)
|
||||||
// Start the service
|
|
||||||
err := startService([]string{}) // Pass empty args since we already saved them
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start service: %v", err)
|
return fmt.Errorf("failed to start service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n")
|
|
||||||
// fmt.Printf("================================================================================\n")
|
|
||||||
|
|
||||||
// Watch the log file
|
// Watch the log file
|
||||||
return watchLogFile(true)
|
return watchLogFile(true)
|
||||||
}
|
}
|
||||||
@@ -509,11 +517,89 @@ func getServiceStatus() (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// showServiceConfig displays current saved service configuration
|
||||||
|
func showServiceConfig() {
|
||||||
|
configPath := getServiceArgsPath()
|
||||||
|
fmt.Printf("Service configuration file: %s\n", configPath)
|
||||||
|
|
||||||
|
args, err := loadServiceArgs()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("No saved configuration found or error loading: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
fmt.Println("No saved service arguments found")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Saved service arguments: %v\n", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isWindowsService() bool {
|
func isWindowsService() bool {
|
||||||
isWindowsService, err := svc.IsWindowsService()
|
isWindowsService, err := svc.IsWindowsService()
|
||||||
return err == nil && isWindowsService
|
return err == nil && isWindowsService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rotateLogFile handles daily log rotation
|
||||||
|
func rotateLogFile(logDir string, logFile string) error {
|
||||||
|
// Get current log file info
|
||||||
|
info, err := os.Stat(logFile)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil // No current log file to rotate
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to stat log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if log file is from today
|
||||||
|
now := time.Now()
|
||||||
|
fileTime := info.ModTime()
|
||||||
|
|
||||||
|
// If the log file is from today, no rotation needed
|
||||||
|
if now.Year() == fileTime.Year() && now.YearDay() == fileTime.YearDay() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create rotated filename with date
|
||||||
|
rotatedName := fmt.Sprintf("olm-%s.log", fileTime.Format("2006-01-02"))
|
||||||
|
rotatedPath := filepath.Join(logDir, rotatedName)
|
||||||
|
|
||||||
|
// Rename current log file to dated filename
|
||||||
|
err = os.Rename(logFile, rotatedPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to rotate log file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up old log files (keep last 30 days)
|
||||||
|
cleanupOldLogFiles(logDir, 30)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupOldLogFiles removes log files older than specified days
|
||||||
|
func cleanupOldLogFiles(logDir string, daysToKeep int) {
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -daysToKeep)
|
||||||
|
|
||||||
|
files, err := os.ReadDir(logDir)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
if !file.IsDir() && strings.HasPrefix(file.Name(), "olm-") && strings.HasSuffix(file.Name(), ".log") {
|
||||||
|
filePath := filepath.Join(logDir, file.Name())
|
||||||
|
info, err := file.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.ModTime().Before(cutoff) {
|
||||||
|
os.Remove(filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setupWindowsEventLog() {
|
func setupWindowsEventLog() {
|
||||||
// Create log directory if it doesn't exist
|
// Create log directory if it doesn't exist
|
||||||
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm", "logs")
|
||||||
@@ -524,6 +610,14 @@ func setupWindowsEventLog() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logFile := filepath.Join(logDir, "olm.log")
|
logFile := filepath.Join(logDir, "olm.log")
|
||||||
|
|
||||||
|
// Rotate log file if needed
|
||||||
|
err = rotateLogFile(logDir, logFile)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to rotate log file: %v\n", err)
|
||||||
|
// Continue anyway to create new log file
|
||||||
|
}
|
||||||
|
|
||||||
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
file, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to open log file: %v\n", err)
|
fmt.Printf("Failed to open log file: %v\n", err)
|
||||||
|
|||||||
637
websocket/client.go
Normal file
637
websocket/client.go
Normal file
@@ -0,0 +1,637 @@
|
|||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"software.sslmate.com/src/go-pkcs12"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenResponse struct {
|
||||||
|
Data struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
} `json:"data"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WSMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is not json anymore
|
||||||
|
type Config struct {
|
||||||
|
ID string
|
||||||
|
Secret string
|
||||||
|
Endpoint string
|
||||||
|
TlsClientCert string // legacy PKCS12 file path
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
config *Config
|
||||||
|
conn *websocket.Conn
|
||||||
|
baseURL string
|
||||||
|
handlers map[string]MessageHandler
|
||||||
|
done chan struct{}
|
||||||
|
handlersMux sync.RWMutex
|
||||||
|
reconnectInterval time.Duration
|
||||||
|
isConnected bool
|
||||||
|
reconnectMux sync.RWMutex
|
||||||
|
pingInterval time.Duration
|
||||||
|
pingTimeout time.Duration
|
||||||
|
onConnect func() error
|
||||||
|
onTokenUpdate func(token string)
|
||||||
|
writeMux sync.Mutex
|
||||||
|
clientType string // Type of client (e.g., "newt", "olm")
|
||||||
|
tlsConfig TLSConfig
|
||||||
|
configNeedsSave bool // Flag to track if config needs to be saved
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientOption func(*Client)
|
||||||
|
|
||||||
|
type MessageHandler func(message WSMessage)
|
||||||
|
|
||||||
|
// TLSConfig holds TLS configuration options
|
||||||
|
type TLSConfig struct {
|
||||||
|
// New separate certificate support
|
||||||
|
ClientCertFile string
|
||||||
|
ClientKeyFile string
|
||||||
|
CAFiles []string
|
||||||
|
|
||||||
|
// Existing PKCS12 support (deprecated)
|
||||||
|
PKCS12File string
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBaseURL sets the base URL for the client
|
||||||
|
func WithBaseURL(url string) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.baseURL = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig sets the TLS configuration for the client
|
||||||
|
func WithTLSConfig(config TLSConfig) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.tlsConfig = config
|
||||||
|
// For backward compatibility, also set the legacy field
|
||||||
|
if config.PKCS12File != "" {
|
||||||
|
c.config.TlsClientCert = config.PKCS12File
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) OnConnect(callback func() error) {
|
||||||
|
c.onConnect = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) OnTokenUpdate(callback func(token string)) {
|
||||||
|
c.onTokenUpdate = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new websocket client
|
||||||
|
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
|
||||||
|
config := &Config{
|
||||||
|
ID: ID,
|
||||||
|
Secret: secret,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
config: config,
|
||||||
|
baseURL: endpoint, // default value
|
||||||
|
handlers: make(map[string]MessageHandler),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
reconnectInterval: 3 * time.Second,
|
||||||
|
isConnected: false,
|
||||||
|
pingInterval: pingInterval,
|
||||||
|
pingTimeout: pingTimeout,
|
||||||
|
clientType: clientType,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply options before loading config
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opt(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) GetConfig() *Config {
|
||||||
|
return c.config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes the WebSocket connection
|
||||||
|
func (c *Client) Connect() error {
|
||||||
|
go c.connectWithRetry()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the WebSocket connection gracefully
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
// Signal shutdown to all goroutines first
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
close(c.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set connection status to false
|
||||||
|
c.setConnected(false)
|
||||||
|
|
||||||
|
// Close the WebSocket connection gracefully
|
||||||
|
if c.conn != nil {
|
||||||
|
// Send close message
|
||||||
|
c.writeMux.Lock()
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
|
||||||
|
// Close the connection
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends a message through the WebSocket connection
|
||||||
|
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||||
|
if c.conn == nil {
|
||||||
|
return fmt.Errorf("not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := WSMessage{
|
||||||
|
Type: messageType,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||||
|
|
||||||
|
c.writeMux.Lock()
|
||||||
|
defer c.writeMux.Unlock()
|
||||||
|
return c.conn.WriteJSON(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
||||||
|
stopChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
count := 0
|
||||||
|
maxAttempts := 10
|
||||||
|
|
||||||
|
err := c.SendMessage(messageType, data) // Send immediately
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send initial message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if count >= maxAttempts {
|
||||||
|
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = c.SendMessage(messageType, data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send message: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
close(stopChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterHandler registers a handler for a specific message type
|
||||||
|
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||||
|
c.handlersMux.Lock()
|
||||||
|
defer c.handlersMux.Unlock()
|
||||||
|
c.handlers[messageType] = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getToken() (string, error) {
|
||||||
|
// Parse the base URL to ensure we have the correct hostname
|
||||||
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse base URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have the base URL without trailing slashes
|
||||||
|
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||||
|
|
||||||
|
var tlsConfig *tls.Config = nil
|
||||||
|
|
||||||
|
// Use new TLS configuration method
|
||||||
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
|
tlsConfig, err = c.setupTLS()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for environment variable to skip TLS verification
|
||||||
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenData map[string]interface{}
|
||||||
|
|
||||||
|
// Get a new token
|
||||||
|
if c.clientType == "newt" {
|
||||||
|
tokenData = map[string]interface{}{
|
||||||
|
"newtId": c.config.ID,
|
||||||
|
"secret": c.config.Secret,
|
||||||
|
}
|
||||||
|
} else if c.clientType == "olm" {
|
||||||
|
tokenData = map[string]interface{}{
|
||||||
|
"olmId": c.config.ID,
|
||||||
|
"secret": c.config.Secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(tokenData)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to marshal token request data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request
|
||||||
|
req, err := http.NewRequest(
|
||||||
|
"POST",
|
||||||
|
baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token",
|
||||||
|
bytes.NewBuffer(jsonData),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||||
|
|
||||||
|
// Make the request
|
||||||
|
client := &http.Client{}
|
||||||
|
if tlsConfig != nil {
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp TokenResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
|
logger.Error("Failed to decode token response.")
|
||||||
|
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tokenResp.Success {
|
||||||
|
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResp.Data.Token == "" {
|
||||||
|
return "", fmt.Errorf("received empty token from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||||
|
|
||||||
|
return tokenResp.Data.Token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) connectWithRetry() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
err := c.establishConnection()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||||
|
time.Sleep(c.reconnectInterval)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) establishConnection() error {
|
||||||
|
// Get token for authentication
|
||||||
|
token, err := c.getToken()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.onTokenUpdate != nil {
|
||||||
|
c.onTokenUpdate(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the base URL to determine protocol and hostname
|
||||||
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse base URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine WebSocket protocol based on HTTP protocol
|
||||||
|
wsProtocol := "wss"
|
||||||
|
if baseURL.Scheme == "http" {
|
||||||
|
wsProtocol = "ws"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create WebSocket URL
|
||||||
|
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
|
||||||
|
u, err := url.Parse(wsURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add token to query parameters
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("token", token)
|
||||||
|
q.Set("clientType", c.clientType)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Connect to WebSocket
|
||||||
|
dialer := websocket.DefaultDialer
|
||||||
|
|
||||||
|
// Use new TLS configuration method
|
||||||
|
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||||
|
logger.Info("Setting up TLS configuration for WebSocket connection")
|
||||||
|
tlsConfig, err := c.setupTLS()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig = tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for environment variable to skip TLS verification for WebSocket connection
|
||||||
|
if os.Getenv("SKIP_TLS_VERIFY") == "true" {
|
||||||
|
if dialer.TLSClientConfig == nil {
|
||||||
|
dialer.TLSClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||||
|
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := dialer.Dial(u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn = conn
|
||||||
|
c.setConnected(true)
|
||||||
|
|
||||||
|
// Start the ping monitor
|
||||||
|
go c.pingMonitor()
|
||||||
|
// Start the read pump with disconnect detection
|
||||||
|
go c.readPumpWithDisconnectDetection()
|
||||||
|
|
||||||
|
if c.onConnect != nil {
|
||||||
|
if err := c.onConnect(); err != nil {
|
||||||
|
logger.Error("OnConnect callback failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTLS configures TLS based on the TLS configuration
|
||||||
|
func (c *Client) setupTLS() (*tls.Config, error) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
// Handle new separate certificate configuration
|
||||||
|
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||||
|
logger.Info("Loading separate certificate files for mTLS")
|
||||||
|
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||||
|
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||||
|
|
||||||
|
// Load client certificate and key
|
||||||
|
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client certificate pair: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
// Load CA certificates for remote validation if specified
|
||||||
|
if len(c.tlsConfig.CAFiles) > 0 {
|
||||||
|
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
for _, caFile := range c.tlsConfig.CAFiles {
|
||||||
|
caCert, err := os.ReadFile(caFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse as PEM first, then DER
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
// If PEM parsing failed, try DER
|
||||||
|
cert, err := x509.ParseCertificate(caCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err)
|
||||||
|
}
|
||||||
|
caCertPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||||
|
if c.tlsConfig.PKCS12File != "" {
|
||||||
|
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
|
||||||
|
return c.setupPKCS12TLS()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy fallback using config.TlsClientCert
|
||||||
|
if c.config.TlsClientCert != "" {
|
||||||
|
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||||
|
return loadClientCertificate(c.config.TlsClientCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupPKCS12TLS loads TLS configuration from PKCS12 file
|
||||||
|
func (c *Client) setupPKCS12TLS() (*tls.Config, error) {
|
||||||
|
return loadClientCertificate(c.tlsConfig.PKCS12File)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pingMonitor sends pings at a short interval and triggers reconnect on failure
|
||||||
|
func (c *Client) pingMonitor() {
|
||||||
|
ticker := time.NewTicker(c.pingInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if c.conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.writeMux.Lock()
|
||||||
|
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
// Check if we're shutting down before logging error and reconnecting
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Expected during shutdown
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
logger.Error("Ping failed: %v", err)
|
||||||
|
c.reconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readPumpWithDisconnectDetection reads messages and triggers reconnect on error
|
||||||
|
func (c *Client) readPumpWithDisconnectDetection() {
|
||||||
|
defer func() {
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
// Only attempt reconnect if we're not shutting down
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Shutting down, don't reconnect
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.reconnect()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
var msg WSMessage
|
||||||
|
err := c.conn.ReadJSON(&msg)
|
||||||
|
if err != nil {
|
||||||
|
// Check if we're shutting down before logging error
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
// Expected during shutdown, don't log as error
|
||||||
|
logger.Debug("WebSocket connection closed during shutdown")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// Unexpected error during normal operation
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||||
|
logger.Error("WebSocket read error: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("WebSocket connection closed: %v", err)
|
||||||
|
}
|
||||||
|
return // triggers reconnect via defer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handlersMux.RLock()
|
||||||
|
if handler, ok := c.handlers[msg.Type]; ok {
|
||||||
|
handler(msg)
|
||||||
|
}
|
||||||
|
c.handlersMux.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) reconnect() {
|
||||||
|
c.setConnected(false)
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only reconnect if we're not shutting down
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
go c.connectWithRetry()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) setConnected(status bool) {
|
||||||
|
c.reconnectMux.Lock()
|
||||||
|
defer c.reconnectMux.Unlock()
|
||||||
|
c.isConnected = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||||
|
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
|
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||||
|
// Read the PKCS12 file
|
||||||
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PKCS12 with empty password for non-encrypted files
|
||||||
|
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
cert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{certificate.Raw},
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Add CA certificates if present
|
||||||
|
rootCAs, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
|
||||||
|
}
|
||||||
|
if len(caCerts) > 0 {
|
||||||
|
for _, caCert := range caCerts {
|
||||||
|
rootCAs.AddCert(caCert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS configuration
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user