diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 1479a55..db55d7c 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -39,7 +39,7 @@ jobs: run: | TAG=${{ env.TAG }} if [ -f main.go ]; then - sed -i 's/Newt version replaceme/Newt version '"$TAG"'/' main.go + sed -i 's/version_replaceme/'"$TAG"'/' main.go echo "Updated main.go with version $TAG" else echo "main.go not found" diff --git a/main.go b/main.go index 091406c..e070a29 100644 --- a/main.go +++ b/main.go @@ -173,18 +173,22 @@ func main() { flag.Parse() - newtVersion := "Newt version replaceme" - if *version { - fmt.Println(newtVersion) - os.Exit(0) - } else { - logger.Info(newtVersion) - } - logger.Init() loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + newtVersion := "version_replaceme" + if *version { + fmt.Println("Newt version " + newtVersion) + os.Exit(0) + } else { + logger.Info("Newt version " + newtVersion) + } + + if err := CheckForUpdate("fosrl", "newt", newtVersion); err != nil { + logger.Error("Error checking for updates: %v\n", err) + } + // parse the mtu string into an int mtuInt, err = strconv.Atoi(mtu) if err != nil { @@ -466,6 +470,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub results := make([]nodeResult, len(exitNodes)) const pingAttempts = 3 for i, node := range exitNodes { + if connected && node.WasPreviouslyConnected { + logger.Info("Skipping ping for previously connected exit node so we pick another %d (%s)", node.ID, node.Endpoint) + continue + } + var totalLatency time.Duration var lastErr error successes := 0 diff --git a/updates.go b/updates.go new file mode 100644 index 0000000..ef5bbb5 --- /dev/null +++ b/updates.go @@ -0,0 +1,173 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +// GitHubRelease represents the GitHub API response for a release +type GitHubRelease struct { + TagName string `json:"tag_name"` + Name string `json:"name"` + HTMLURL string `json:"html_url"` +} + +// Version represents a semantic version +type Version struct { + Major int + Minor int + Patch int +} + +// parseVersion parses a semantic version string (e.g., "v1.2.3" or "1.2.3") +func parseVersion(versionStr string) (Version, error) { + // Remove 'v' prefix if present + versionStr = strings.TrimPrefix(versionStr, "v") + + parts := strings.Split(versionStr, ".") + if len(parts) != 3 { + return Version{}, fmt.Errorf("invalid version format: %s", versionStr) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return Version{}, fmt.Errorf("invalid major version: %s", parts[0]) + } + + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return Version{}, fmt.Errorf("invalid minor version: %s", parts[1]) + } + + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return Version{}, fmt.Errorf("invalid patch version: %s", parts[2]) + } + + return Version{Major: major, Minor: minor, Patch: patch}, nil +} + +// isNewer returns true if v2 is newer than v1 +func (v1 Version) isNewer(v2 Version) bool { + if v2.Major > v1.Major { + return true + } + if v2.Major < v1.Major { + return false + } + + if v2.Minor > v1.Minor { + return true + } + if v2.Minor < v1.Minor { + return false + } + + return v2.Patch > v1.Patch +} + +// String returns the version as a string +func (v Version) String() string { + return fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch) +} + +// CheckForUpdate checks GitHub for a newer version and prints an update banner if found +func CheckForUpdate(owner, repo, currentVersion string) error { + if currentVersion == "version_replaceme" { + return nil + } + + // GitHub API URL for latest release + url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo) + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Second, + } + + // Make the request + resp, err := client.Get(url) + if err != nil { + return fmt.Errorf("failed to fetch release info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("GitHub API returned status: %d", resp.StatusCode) + } + + // Parse the JSON response + var release GitHubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return fmt.Errorf("failed to parse release info: %w", err) + } + + // Parse current and latest versions + currentVer, err := parseVersion(currentVersion) + if err != nil { + return fmt.Errorf("invalid current version: %w", err) + } + + latestVer, err := parseVersion(release.TagName) + if err != nil { + return fmt.Errorf("invalid latest version: %w", err) + } + + // Check if update is available + if currentVer.isNewer(latestVer) { + printUpdateBanner(currentVer.String(), latestVer.String(), release.HTMLURL) + } + + return nil +} + +// printUpdateBanner prints a colorful update notification banner +func printUpdateBanner(currentVersion, latestVersion, releaseURL string) { + const contentWidth = 70 // width between the border lines + + borderTop := "╔" + strings.Repeat("═", contentWidth) + "╗" + borderMid := "╠" + strings.Repeat("═", contentWidth) + "╣" + borderBot := "╚" + strings.Repeat("═", contentWidth) + "╝" + emptyLine := "║" + strings.Repeat(" ", contentWidth) + "║" + + lines := []string{ + borderTop, + "║" + centerText("UPDATE AVAILABLE", contentWidth) + "║", + borderMid, + emptyLine, + "║ Current Version: " + padRight(currentVersion, contentWidth-19) + "║", + "║ Latest Version: " + padRight(latestVersion, contentWidth-19) + "║", + emptyLine, + "║ A newer version is available! Please update to get the" + padRight("", contentWidth-56) + "║", + "║ latest features, bug fixes, and security improvements." + padRight("", contentWidth-56) + "║", + emptyLine, + "║ Release URL: " + padRight(releaseURL, contentWidth-15) + "║", + emptyLine, + borderBot, + } + + for _, line := range lines { + fmt.Println(line) + } +} + +// padRight pads s with spaces on the right to the given width +func padRight(s string, width int) string { + if len(s) > width { + return s[:width] + } + return s + strings.Repeat(" ", width-len(s)) +} + +// centerText centers s in a field of width w +func centerText(s string, w int) string { + if len(s) >= w { + return s[:w] + } + padding := (w - len(s)) / 2 + return strings.Repeat(" ", padding) + s + strings.Repeat(" ", w-len(s)-padding) +} diff --git a/websocket/client.go b/websocket/client.go index 4bd2c7d..9e11f01 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -33,6 +33,7 @@ type Client struct { pingTimeout time.Duration onConnect func() error onTokenUpdate func(token string) + writeMux sync.Mutex } type ClientOption func(*Client) @@ -73,7 +74,7 @@ func NewClient(newtID, secret string, endpoint string, pingInterval time.Duratio baseURL: endpoint, // default value handlers: make(map[string]MessageHandler), done: make(chan struct{}), - reconnectInterval: 10 * time.Second, + reconnectInterval: 3 * time.Second, isConnected: false, pingInterval: pingInterval, pingTimeout: pingTimeout, @@ -125,6 +126,8 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } + c.writeMux.Lock() + defer c.writeMux.Unlock() return c.conn.WriteJSON(msg) } @@ -220,6 +223,7 @@ func (c *Client) getToken() (string, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + logger.Error("Failed to decode token check response. Raw response: %s", resp.Body) return "", fmt.Errorf("failed to decode token check response: %w", err) } @@ -268,10 +272,7 @@ func (c *Client) getToken() (string, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - // print out the token response for debugging - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - logger.Info("Token response: %s", buf.String()) + logger.Error("Failed to decode token response. Raw response: %s", resp.Body) return "", fmt.Errorf("failed to decode token response: %w", err) } @@ -386,7 +387,10 @@ func (c *Client) pingMonitor() { if c.conn == nil { return } - if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)); err != nil { + c.writeMux.Lock() + err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) + c.writeMux.Unlock() + if err != nil { logger.Error("Ping failed: %v", err) c.reconnect() return