Add optional user token to validate

Former-commit-id: 5734684a21
This commit is contained in:
Owen
2025-11-07 14:51:00 -08:00
parent befab0f8d1
commit 235877c379
5 changed files with 52 additions and 34 deletions

View File

@@ -18,6 +18,7 @@ type OlmConfig struct {
ID string `json:"id"` ID string `json:"id"`
Secret string `json:"secret"` Secret string `json:"secret"`
OrgID string `json:"org"` OrgID string `json:"org"`
UserToken string `json:"userToken"`
// Network settings // Network settings
MTU int `json:"mtu"` MTU int `json:"mtu"`
@@ -193,6 +194,10 @@ func loadConfigFromEnv(config *OlmConfig) {
config.OrgID = val config.OrgID = val
config.sources["org"] = string(SourceEnv) config.sources["org"] = string(SourceEnv)
} }
if val := os.Getenv("USER_TOKEN"); val != "" {
config.UserToken = val
config.sources["userToken"] = string(SourceEnv)
}
if val := os.Getenv("MTU"); val != "" { if val := os.Getenv("MTU"); val != "" {
if mtu, err := strconv.Atoi(val); err == nil { if mtu, err := strconv.Atoi(val); err == nil {
config.MTU = mtu config.MTU = mtu
@@ -249,6 +254,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"id": config.ID, "id": config.ID,
"secret": config.Secret, "secret": config.Secret,
"org": config.OrgID, "org": config.OrgID,
"userToken": config.UserToken,
"mtu": config.MTU, "mtu": config.MTU,
"dns": config.DNS, "dns": config.DNS,
"logLevel": config.LogLevel, "logLevel": config.LogLevel,
@@ -266,6 +272,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID") serviceFlags.StringVar(&config.ID, "id", config.ID, "Olm ID")
serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret") serviceFlags.StringVar(&config.Secret, "secret", config.Secret, "Olm secret")
serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID") serviceFlags.StringVar(&config.OrgID, "org", config.OrgID, "Organization ID")
serviceFlags.StringVar(&config.UserToken, "user-token", config.UserToken, "User token (optional)")
serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use") serviceFlags.IntVar(&config.MTU, "mtu", config.MTU, "MTU to use")
serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use") serviceFlags.StringVar(&config.DNS, "dns", config.DNS, "DNS server to use")
serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") serviceFlags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
@@ -298,6 +305,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.OrgID != origValues["org"].(string) { if config.OrgID != origValues["org"].(string) {
config.sources["org"] = string(SourceCLI) config.sources["org"] = string(SourceCLI)
} }
if config.UserToken != origValues["userToken"].(string) {
config.sources["userToken"] = string(SourceCLI)
}
if config.MTU != origValues["mtu"].(int) { if config.MTU != origValues["mtu"].(int) {
config.sources["mtu"] = string(SourceCLI) config.sources["mtu"] = string(SourceCLI)
} }
@@ -384,6 +394,10 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.OrgID = src.OrgID dest.OrgID = src.OrgID
dest.sources["org"] = string(SourceFile) dest.sources["org"] = string(SourceFile)
} }
if src.UserToken != "" {
dest.UserToken = src.UserToken
dest.sources["userToken"] = string(SourceFile)
}
if src.MTU != 0 && src.MTU != 1280 { if src.MTU != 0 && src.MTU != 1280 {
dest.MTU = src.MTU dest.MTU = src.MTU
dest.sources["mtu"] = string(SourceFile) dest.sources["mtu"] = string(SourceFile)
@@ -490,6 +504,7 @@ func (c *OlmConfig) ShowConfig() {
fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id")) fmt.Printf(" id = %s [%s]\n", formatValue("id", c.ID), getSource("id"))
fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret")) fmt.Printf(" secret = %s [%s]\n", formatValue("secret", c.Secret), getSource("secret"))
fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org")) fmt.Printf(" org = %s [%s]\n", formatValue("org", c.OrgID), getSource("org"))
fmt.Printf(" user-token = %s [%s]\n", formatValue("userToken", c.UserToken), getSource("userToken"))
// Network settings // Network settings
fmt.Println("\nNetwork:") fmt.Println("\nNetwork:")

View File

@@ -195,6 +195,7 @@ func main() {
Endpoint: config.Endpoint, Endpoint: config.Endpoint,
ID: config.ID, ID: config.ID,
Secret: config.Secret, Secret: config.Secret,
UserToken: config.UserToken,
MTU: config.MTU, MTU: config.MTU,
DNS: config.DNS, DNS: config.DNS,
InterfaceName: config.InterfaceName, InterfaceName: config.InterfaceName,

View File

@@ -562,6 +562,7 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
func sendPing(olm *websocket.Client) error { func sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]interface{}{ err := olm.SendMessage("olm/ping", map[string]interface{}{
"timestamp": time.Now().Unix(), "timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
}) })
if err != nil { if err != nil {
logger.Error("Failed to send ping message: %v", err) logger.Error("Failed to send ping message: %v", err)

View File

@@ -24,6 +24,7 @@ type Config struct {
Endpoint string Endpoint string
ID string ID string
Secret string Secret string
UserToken string
// Network settings // Network settings
MTU int MTU int
@@ -107,6 +108,7 @@ func Run(ctx context.Context, config Config) {
id = config.ID id = config.ID
secret = config.Secret secret = config.Secret
endpoint = config.Endpoint endpoint = config.Endpoint
userToken = config.UserToken
) )
// Main event loop that handles connect, disconnect, and reconnect // Main event loop that handles connect, disconnect, and reconnect
@@ -129,12 +131,13 @@ func Run(ctx context.Context, config Config) {
id = req.ID id = req.ID
secret = req.Secret secret = req.Secret
endpoint = req.Endpoint endpoint = req.Endpoint
userToken := req.UserToken
// Start the tunnel process with the new credentials // Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" { if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials") logger.Info("Starting tunnel with new credentials")
tunnelRunning = true tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, endpoint) go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
} }
case <-apiServer.GetDisconnectChannel(): case <-apiServer.GetDisconnectChannel():
@@ -144,13 +147,14 @@ func Run(ctx context.Context, config Config) {
id = "" id = ""
secret = "" secret = ""
endpoint = "" endpoint = ""
userToken = ""
default: default:
// If we have credentials and no tunnel is running, start it // If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && !tunnelRunning { if id != "" && secret != "" && endpoint != "" && !tunnelRunning {
logger.Info("Starting tunnel process with initial credentials") logger.Info("Starting tunnel process with initial credentials")
tunnelRunning = true tunnelRunning = true
go TunnelProcess(ctx, config, id, secret, endpoint) go TunnelProcess(ctx, config, id, secret, userToken, endpoint)
} else if id == "" || secret == "" || endpoint == "" { } else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled // If we don't have credentials, check if API is enabled
if !config.EnableAPI { if !config.EnableAPI {
@@ -181,7 +185,7 @@ shutdown:
logger.Info("Olm service shutting down") logger.Info("Olm service shutting down")
} }
func TunnelProcess(ctx context.Context, config Config, id string, secret string, endpoint string) { func TunnelProcess(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) {
// Create a cancellable context for this tunnel process // Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx) tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel tunnelCancel = cancel
@@ -200,9 +204,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
// Create a new olm client using the provided credentials // Create a new olm client using the provided credentials
olm, err := websocket.NewClient( olm, err := websocket.NewClient(
"olm",
id, // Use provided ID id, // Use provided ID
secret, // Use provided secret secret, // Use provided secret
userToken, // Use provided user token OPTIONAL
endpoint, // Use provided endpoint endpoint, // Use provided endpoint
config.PingIntervalDuration, config.PingIntervalDuration,
config.PingTimeoutDuration, config.PingTimeoutDuration,

View File

@@ -39,6 +39,7 @@ type Config struct {
Secret string Secret string
Endpoint string Endpoint string
TlsClientCert string // legacy PKCS12 file path TlsClientCert string // legacy PKCS12 file path
UserToken string // optional user token for websocket authentication
} }
type Client struct { type Client struct {
@@ -103,11 +104,12 @@ func (c *Client) OnTokenUpdate(callback func(token string)) {
} }
// NewClient creates a new websocket client // NewClient creates a new websocket client
func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
ID: ID, ID: ID,
Secret: secret, Secret: secret,
Endpoint: endpoint, Endpoint: endpoint,
UserToken: userToken,
} }
client := &Client{ client := &Client{
@@ -119,7 +121,7 @@ func NewClient(clientType string, ID, secret string, endpoint string, pingInterv
isConnected: false, isConnected: false,
pingInterval: pingInterval, pingInterval: pingInterval,
pingTimeout: pingTimeout, pingTimeout: pingTimeout,
clientType: clientType, clientType: "olm",
} }
// Apply options before loading config // Apply options before loading config
@@ -263,18 +265,10 @@ func (c *Client) getToken() (string, error) {
var tokenData map[string]interface{} var tokenData map[string]interface{}
// Get a new token
if c.clientType == "newt" {
tokenData = map[string]interface{}{
"newtId": c.config.ID,
"secret": c.config.Secret,
}
} else if c.clientType == "olm" {
tokenData = map[string]interface{}{ tokenData = map[string]interface{}{
"olmId": c.config.ID, "olmId": c.config.ID,
"secret": c.config.Secret, "secret": c.config.Secret,
} }
}
jsonData, err := json.Marshal(tokenData) jsonData, err := json.Marshal(tokenData)
if err != nil { if err != nil {
@@ -384,6 +378,9 @@ func (c *Client) establishConnection() error {
q := u.Query() q := u.Query()
q.Set("token", token) q.Set("token", token)
q.Set("clientType", c.clientType) q.Set("clientType", c.clientType)
if c.config.UserToken != "" {
q.Set("userToken", c.config.UserToken)
}
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
// Connect to WebSocket // Connect to WebSocket