diff --git a/assets/connect.svg b/assets/connect.svg new file mode 100644 index 0000000..3accea2 --- /dev/null +++ b/assets/connect.svg @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/assets/icon.svg b/assets/icon.svg new file mode 100644 index 0000000..5f07e31 --- /dev/null +++ b/assets/icon.svg @@ -0,0 +1,10 @@ + + + + + + + + + R + diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go index 307e3c8..f5fffab 100644 --- a/cmd/rdpgw/main.go +++ b/cmd/rdpgw/main.go @@ -213,6 +213,9 @@ func main() { // for sso callbacks r.HandleFunc("/tokeninfo", web.TokenInfo) + // API routes + api := r.PathPrefix("/api/v1").Subrouter() + // gateway endpoint rdp := r.PathPrefix(gatewayEndPoint).Subrouter() @@ -223,6 +226,18 @@ func main() { r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload))) r.HandleFunc("/callback", o.HandleCallback) + // Web interface and API routes (authenticated) + r.Handle("/", o.Authenticated(http.HandlerFunc(h.HandleWebInterface))) + api.Handle("/hosts", o.Authenticated(http.HandlerFunc(h.HandleHostList))) + api.Handle("/user", o.Authenticated(http.HandlerFunc(h.HandleUserInfo))) + + // Static files (no authentication required) + r.HandleFunc("/static/style.css", h.ServeStaticFile("style.css")) + r.HandleFunc("/static/app.js", h.ServeStaticFile("app.js")) + // Asset files (no authentication required) + r.HandleFunc("/assets/connect.svg", h.ServeAssetFile("connect.svg")) + r.HandleFunc("/assets/icon.svg", h.ServeAssetFile("icon.svg")) + // only enable un-auth endpoint for openid only config if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() && !conf.Server.HeaderEnabled() { rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol) @@ -241,6 +256,18 @@ func main() { headerAuth := headerConfig.New() r.Handle("/connect", headerAuth.Authenticated(http.HandlerFunc(h.HandleDownload))) + // Web interface and API routes (authenticated) + r.Handle("/", headerAuth.Authenticated(http.HandlerFunc(h.HandleWebInterface))) + api.Handle("/hosts", headerAuth.Authenticated(http.HandlerFunc(h.HandleHostList))) + api.Handle("/user", headerAuth.Authenticated(http.HandlerFunc(h.HandleUserInfo))) + + // Static files (no authentication required) + r.HandleFunc("/static/style.css", h.ServeStaticFile("style.css")) + r.HandleFunc("/static/app.js", h.ServeStaticFile("app.js")) + // Asset files (no authentication required) + r.HandleFunc("/assets/connect.svg", h.ServeAssetFile("connect.svg")) + r.HandleFunc("/assets/icon.svg", h.ServeAssetFile("icon.svg")) + // only enable un-auth endpoint for header only config if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() && !conf.Server.OpenIDEnabled() { rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol) diff --git a/cmd/rdpgw/templates/README.md b/cmd/rdpgw/templates/README.md new file mode 100644 index 0000000..f60acf8 --- /dev/null +++ b/cmd/rdpgw/templates/README.md @@ -0,0 +1,95 @@ +# RDP Gateway Web Interface Templates + +This directory contains the customizable web interface templates for RDP Gateway. + +## Files + +### `index.html` +The main HTML template for the web interface. This file uses Go template syntax and can be customized to match your organization's branding. + +**Template Variables Available:** +- `{{.Title}}` - Page title +- `{{.Logo}}` - Header logo text +- `{{.PageTitle}}` - Main page heading +- `{{.SelectServerMessage}}` - Default button text +- `{{.PreparingMessage}}` - Loading message +- `{{.AutoLaunchMessage}}` - Auto-launch notice text + +### `style.css` +The CSS stylesheet for the web interface. Modify this file to customize: +- Colors and branding +- Layout and spacing +- Fonts and typography +- Responsive behavior + +### `app.js` +The JavaScript file containing the web interface logic. This includes: +- Server list loading and rendering +- User authentication display +- **Automatic RDP client launching** (multiple methods) +- File download fallback +- Progress animations + +### `config-example.json` +Example configuration structure showing available customization options. These values are set as defaults in the code but can be integrated with your main configuration system. + +## Auto-Launch Functionality + +The interface automatically attempts to launch RDP clients using **actual RDP file content**: + +### How It Works: +1. **Fetches RDP Content**: Gets the complete RDP file configuration from `/api/rdp-content` +2. **Creates Data URL**: Converts RDP content to a downloadable blob +3. **Platform-Specific Launch**: + - **Windows**: Downloads .rdp file which auto-opens with mstsc + - **macOS**: Downloads .rdp file which auto-opens with Microsoft Remote Desktop + - **Universal**: Creates temporary download that browsers handle appropriately + +### Technical Implementation: +- **`/api/rdp-content`** endpoint generates actual RDP file content with proper tokens +- **Data URLs** created from RDP content for browser download +- **Automatic file association** triggers RDP client launch +- **Graceful fallbacks** ensure users always get the RDP file + +## Customization + +To customize the interface: + +1. **Copy this templates directory** to your preferred location +2. **Set the templates path** in your RDP Gateway configuration +3. **Edit the files** to match your branding requirements +4. **Restart RDP Gateway** to load the new templates + +If template files are missing, the system automatically falls back to embedded templates to ensure the interface remains functional. + +## API Endpoints + +The web interface uses these authenticated API endpoints: + +- **`/api/hosts`** - Returns available servers for the user (JSON) +- **`/api/user`** - Returns current user information (JSON) +- **`/api/rdp-content`** - Returns RDP file content as text for auto-launch +- **`/connect`** - Downloads RDP file (traditional endpoint) + +## Static File Serving + +The following URLs serve static files: +- `/static/style.css` - CSS stylesheet +- `/static/app.js` - JavaScript application + +These files are served without authentication requirements for better performance. + +## Browser Compatibility + +The interface supports: +- Modern browsers (Chrome, Firefox, Safari, Edge) +- Mobile responsive design +- Protocol handlers for RDP client launching +- Graceful fallbacks for unsupported features + +## Security Considerations + +- Template files are served from the server filesystem +- Static files include cache headers for performance +- User authentication is required for the main interface +- API endpoints validate authentication before serving data \ No newline at end of file diff --git a/cmd/rdpgw/templates/app.js b/cmd/rdpgw/templates/app.js new file mode 100644 index 0000000..e4ea81c --- /dev/null +++ b/cmd/rdpgw/templates/app.js @@ -0,0 +1,223 @@ +// RDP Gateway Web Interface +let userInfo = null; + +// Theme handling - SVG logo works for both light and dark modes +function updateLogo() { + const logoImage = document.getElementById('logoImage'); + if (logoImage) { + logoImage.src = '/assets/icon.svg'; + } +} + +// Configuration +const config = { + progressAnimationDuration: 2000, // ms for progress bar animation +}; + +// Get user initials for avatar +function getUserInitials(name) { + return name.split(' ').map(word => word.charAt(0)).slice(0, 2).join('').toUpperCase() || 'U'; +} + +// Load user information +async function loadUserInfo() { + try { + const response = await fetch('/api/v1/user'); + if (response.ok) { + userInfo = await response.json(); + document.getElementById('username').textContent = userInfo.username; + document.getElementById('userAvatar').textContent = getUserInitials(userInfo.username); + } else { + throw new Error('Failed to load user info'); + } + } catch (error) { + showError('Failed to load user information'); + } +} + +// Load available servers +async function loadServers() { + try { + const response = await fetch('/api/v1/hosts'); + if (response.ok) { + const servers = await response.json(); + renderServers(servers); + } else { + throw new Error('Failed to load servers'); + } + } catch (error) { + showError('Failed to load available servers'); + } +} + +// Render servers in the grid +function renderServers(servers) { + const grid = document.getElementById('serversGrid'); + grid.innerHTML = ''; + + servers.forEach(server => { + const card = document.createElement('div'); + card.className = 'server-card'; + + const connectButton = document.createElement('button'); + connectButton.className = 'server-connect-button'; + connectButton.textContent = `Connect to ${server.name}`; + connectButton.onclick = (e) => { + e.stopPropagation(); + connectToServer(server, connectButton); + }; + + card.innerHTML = ` +
+
+ Connect +
+
+
${server.name}
+
${server.description}
+
+
+ `; + + card.appendChild(connectButton); + grid.appendChild(card); + }); +} + + +// Show error message +function showError(message) { + const errorDiv = document.getElementById('error'); + errorDiv.textContent = message; + errorDiv.style.display = 'block'; + hideSuccess(); +} + +// Hide error message +function hideError() { + document.getElementById('error').style.display = 'none'; +} + +// Show success message +function showSuccess(message) { + const successDiv = document.getElementById('success'); + successDiv.textContent = message; + successDiv.style.display = 'block'; + hideError(); +} + +// Hide success message +function hideSuccess() { + document.getElementById('success').style.display = 'none'; +} + +// Animate progress bar +function animateProgress(duration = config.progressAnimationDuration) { + const progressFill = document.getElementById('progressFill'); + progressFill.style.width = '0%'; + + let startTime = null; + function animate(currentTime) { + if (!startTime) startTime = currentTime; + const elapsed = currentTime - startTime; + const progress = Math.min(elapsed / duration, 1); + + progressFill.style.width = (progress * 100) + '%'; + + if (progress < 1) { + requestAnimationFrame(animate); + } + } + requestAnimationFrame(animate); +} + +// Generate filename with user initials and random prefix +function generateFilename() { + if (!userInfo) return 'connection.rdp'; + + const initials = getUserInitials(userInfo.username); + const randomStr = Math.random().toString(36).substring(2, 8).toUpperCase(); + return `${initials}_${randomStr}.rdp`; +} + +// Download RDP file +async function downloadRDPFile(url) { + const link = document.createElement('a'); + link.href = url; + link.download = generateFilename(); + link.style.display = 'none'; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); +} + +// Connect to server +async function connectToServer(server, button) { + if (!server) return; + + hideError(); + hideSuccess(); + + const originalButtonText = button.textContent; + const loading = document.getElementById('loading'); + + // Update UI for loading state + button.disabled = true; + button.textContent = 'Downloading...'; + loading.style.display = 'block'; + + // Start progress animation + animateProgress(); + + try { + // Build the RDP download URL + let url = '/connect'; + if (server.address) { + url += '?host=' + encodeURIComponent(server.address); + } + + // Wait a moment for better UX + await new Promise(resolve => setTimeout(resolve, 500)); + + // Download the RDP file + await downloadRDPFile(url); + showSuccess('RDP file downloaded. Please open it with your RDP client.'); + + // Reset UI after a delay + setTimeout(() => { + button.disabled = false; + button.textContent = originalButtonText; + loading.style.display = 'none'; + }, 2000); + + } catch (error) { + console.error('Connection error:', error); + showError('Failed to download RDP file. Please try again.'); + + // Reset UI immediately on error + button.disabled = false; + button.textContent = originalButtonText; + loading.style.display = 'none'; + } +} + + +// Initialize the application +document.addEventListener('DOMContentLoaded', async () => { + // Set initial logo based on theme + updateLogo(); + + // Load data + await loadUserInfo(); + await loadServers(); + + // No additional event handlers needed - buttons are handled in renderServers +}); + +// Handle visibility change (for auto-refresh when tab becomes visible) +document.addEventListener('visibilitychange', () => { + if (!document.hidden) { + // Refresh server list when tab becomes visible + loadServers(); + } +}); \ No newline at end of file diff --git a/cmd/rdpgw/templates/config-example.json b/cmd/rdpgw/templates/config-example.json new file mode 100644 index 0000000..7224cbb --- /dev/null +++ b/cmd/rdpgw/templates/config-example.json @@ -0,0 +1,22 @@ +{ + "branding": { + "title": "RDP Gateway", + "logo": "RDP Gateway", + "page_title": "Select a Server to Connect" + }, + "messages": { + "select_server": "Select a server to connect", + "preparing": "Preparing your connection..." + }, + "ui": { + "progress_animation_duration_ms": 2000, + "auto_select_default": true, + "show_user_avatar": true + }, + "theme": { + "primary_color": "#667eea", + "secondary_color": "#764ba2", + "success_color": "#38b2ac", + "error_color": "#c53030" + } +} \ No newline at end of file diff --git a/cmd/rdpgw/templates/index.html b/cmd/rdpgw/templates/index.html new file mode 100644 index 0000000..fccbb61 --- /dev/null +++ b/cmd/rdpgw/templates/index.html @@ -0,0 +1,45 @@ + + + + + + {{.Title}} + + + + + +
+ +
+
+ Loading... +
+
+ +
+
+

{{.PageTitle}}

+ +
+
+ +
+ +
+ +
+ {{.PreparingMessage}} +
+
+
+
+
+
+ + + + \ No newline at end of file diff --git a/cmd/rdpgw/templates/style.css b/cmd/rdpgw/templates/style.css new file mode 100644 index 0000000..78f696b --- /dev/null +++ b/cmd/rdpgw/templates/style.css @@ -0,0 +1,305 @@ +1:root { + /* Light mode colors (OKLCH) */ + --background: oklch(1 0 0); + --foreground: oklch(0.145 0 0); + --primary: oklch(0.205 0 0); + --secondary: oklch(0.97 0 0); + --accent: oklch(0.97 0 0); + --destructive: oklch(0.577 0.245 27.325); + --border: oklch(0.922 0 0); + --muted: oklch(0.97 0 0); + --muted-foreground: oklch(0.466 0 0); + --card: oklch(1 0 0); + --popover: oklch(1 0 0); + + /* Border radius - Pocket ID system */ + --radius: 0.75rem; + --radius-sm: calc(var(--radius) - 4px); + --radius-md: calc(var(--radius) - 2px); + --radius-lg: var(--radius); + --radius-xl: calc(var(--radius) + 4px); +} + +@media (prefers-color-scheme: dark) { + :root { + /* Dark mode colors */ + --background: oklch(0.145 0 0); + --foreground: oklch(0.985 0 0); + --primary: oklch(0.922 0 0); + --secondary: oklch(0.269 0 0); + --accent: oklch(0.269 0 0); + --destructive: oklch(0.704 0.191 22.216); + --border: oklch(0.269 0 0); + --muted: oklch(0.205 0 0); + --muted-foreground: oklch(0.722 0 0); + --card: oklch(0.145 0 0); + --popover: oklch(0.145 0 0); + } +} + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + background: var(--background); + color: var(--foreground); + min-height: 100vh; + display: flex; + flex-direction: column; + transition: color 0.2s, background-color 0.2s; +} + +.header { + background: var(--card); + border-bottom: 1px solid var(--border); + padding: 1rem 2rem; + display: flex; + justify-content: space-between; + align-items: center; + box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); +} + +.logo { + display: flex; + align-items: center; + gap: 0.75rem; + font-size: 1.25rem; + font-weight: 600; + color: var(--foreground); +} + +.logo img { + height: 2rem; + width: auto; +} + +.user-info { + display: flex; + align-items: center; + gap: 1rem; + color: var(--foreground); +} + +.user-avatar { + width: 40px; + height: 40px; + border-radius: 50%; + background: var(--secondary); + border: 1px solid var(--border); + display: flex; + align-items: center; + justify-content: center; + font-weight: 600; + color: var(--foreground); +} + +.main { + flex: 1; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 2rem; +} + +.container { + background: var(--card); + border-radius: var(--radius-xl); + border: 1px solid var(--border); + padding: 2rem; + box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1); + max-width: 800px; + width: 100%; +} + +.title { + text-align: center; + margin-bottom: 2rem; + color: var(--foreground); + font-weight: 600; +} + +.servers-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); + gap: 1rem; + margin-bottom: 2rem; +} + +.server-card { + border: 1px solid var(--border); + border-radius: var(--radius-lg); + padding: 0; + transition: all 0.2s; + position: relative; + background: var(--card); + display: flex; + flex-direction: column; + overflow: hidden; +} + +.server-card:hover { + transform: translateY(-1px); + box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1); +} + + +.server-content { + padding: 1.5rem; + flex: 1; + display: flex; + gap: 1rem; + align-items: flex-start; +} + +.server-icon { + flex-shrink: 0; + line-height: 1; + display: flex; + align-items: center; + justify-content: center; +} + +.server-icon img { + width: 3rem; + height: 3rem; + object-fit: contain; +} + +.server-info { + flex: 1; + min-width: 0; +} + +.server-name { + font-size: 1.1rem; + font-weight: 600; + margin-bottom: 0.5rem; + color: var(--foreground); + word-wrap: break-word; +} + +.server-description { + color: var(--muted-foreground); + font-size: 0.9rem; + line-height: 1.4; +} + +.server-connect-button { + width: 100%; + background: var(--primary); + color: var(--background); + border: none; + border-top: 1px solid var(--border); + border-radius: 0 0 var(--radius-lg) var(--radius-lg); + padding: 1rem 1.5rem; + font-size: 1rem; + font-weight: 600; + cursor: pointer; + transition: all 0.2s; + margin: 0; +} + +.server-connect-button:hover:not(:disabled) { + background: var(--primary); + opacity: 0.9; +} + +.server-connect-button:disabled { + background: var(--muted); + color: var(--muted-foreground); + opacity: 0.5; + cursor: not-allowed; +} + +.loading { + display: none; + text-align: center; + margin-top: 1rem; + color: var(--foreground); + opacity: 0.7; +} + +.error { + background: color-mix(in oklch, var(--destructive) 10%, transparent); + color: var(--destructive); + border: 1px solid color-mix(in oklch, var(--destructive) 20%, transparent); + padding: 1rem; + border-radius: var(--radius-md); + margin-bottom: 1rem; + display: none; +} + +.success { + background: color-mix(in oklch, var(--primary) 10%, transparent); + color: var(--primary); + border: 1px solid color-mix(in oklch, var(--primary) 20%, transparent); + padding: 1rem; + border-radius: var(--radius-md); + margin-bottom: 1rem; + display: none; + text-align: center; +} + + +.progress-bar { + width: 100%; + height: 4px; + background: var(--border); + border-radius: var(--radius-sm); + overflow: hidden; + margin-top: 0.5rem; +} + +.progress-fill { + height: 100%; + background: var(--primary); + width: 0%; + transition: width 0.3s ease; +} + +/* Responsive design */ +@media (max-width: 768px) { + .header { + padding: 1rem; + } + + .main { + padding: 1rem; + } + + .container { + padding: 1.5rem; + } + + .servers-grid { + grid-template-columns: 1fr; + } + + .logo { + font-size: 1.2rem; + } +} + +/* Custom scrollbar */ +::-webkit-scrollbar { + width: 8px; +} + +::-webkit-scrollbar-track { + background: rgba(255, 255, 255, 0.1); + border-radius: 4px; +} + +::-webkit-scrollbar-thumb { + background: rgba(255, 255, 255, 0.3); + border-radius: 4px; +} + +::-webkit-scrollbar-thumb:hover { + background: rgba(255, 255, 255, 0.5); +} \ No newline at end of file diff --git a/cmd/rdpgw/web/web.go b/cmd/rdpgw/web/web.go index 42bc036..9640680 100644 --- a/cmd/rdpgw/web/web.go +++ b/cmd/rdpgw/web/web.go @@ -5,13 +5,17 @@ import ( "context" "crypto/rand" "encoding/hex" + "encoding/json" "errors" "fmt" "hash/maphash" + "html/template" "log" rnd "math/rand" "net/http" "net/url" + "os" + "path/filepath" "strings" "time" @@ -37,6 +41,31 @@ type Config struct { TemplateFile string RdpSigningCert string RdpSigningKey string + TemplatesPath string +} + +// WebConfig represents the web interface configuration +type WebConfig struct { + Branding struct { + Title string `json:"title"` + Logo string `json:"logo"` + PageTitle string `json:"page_title"` + } `json:"branding"` + Messages struct { + SelectServer string `json:"select_server"` + Preparing string `json:"preparing"` + } `json:"messages"` + UI struct { + ProgressAnimationDurationMs int `json:"progress_animation_duration_ms"` + AutoSelectDefault bool `json:"auto_select_default"` + ShowUserAvatar bool `json:"show_user_avatar"` + } `json:"ui"` + Theme struct { + PrimaryColor string `json:"primary_color"` + SecondaryColor string `json:"secondary_color"` + SuccessColor string `json:"success_color"` + ErrorColor string `json:"error_color"` + } `json:"theme"` } type RdpOpts struct { @@ -57,6 +86,9 @@ type Handler struct { rdpOpts RdpOpts rdpDefaults string rdpSigner *rdpsign.Signer + templatesPath string + webConfig *WebConfig + htmlTemplate *template.Template } func (c *Config) NewHandler() *Handler { @@ -75,6 +107,7 @@ func (c *Config) NewHandler() *Handler { hostSelection: c.HostSelection, rdpOpts: c.RdpOpts, rdpDefaults: c.TemplateFile, + templatesPath: c.TemplatesPath, } // set up RDP signer if config values are set @@ -87,9 +120,231 @@ func (c *Config) NewHandler() *Handler { handler.rdpSigner = signer } + // Set up templates path + if handler.templatesPath == "" { + handler.templatesPath = "./templates" + } + + // Load web configuration + handler.loadWebConfig() + + // Load HTML template + handler.loadHTMLTemplate() + return handler } +// loadWebConfig sets up the web interface configuration with defaults +func (h *Handler) loadWebConfig() { + // Set defaults - these can be overridden by the main config system later + h.webConfig = &WebConfig{} + h.webConfig.Branding.Title = "RDP Gateway" + h.webConfig.Branding.Logo = "RDP Gateway" + h.webConfig.Branding.PageTitle = "Select a Server to Connect" + h.webConfig.Messages.SelectServer = "Select a server to connect" + h.webConfig.Messages.Preparing = "Preparing your connection..." + h.webConfig.UI.ProgressAnimationDurationMs = 2000 + h.webConfig.UI.AutoSelectDefault = true + h.webConfig.UI.ShowUserAvatar = true + h.webConfig.Theme.PrimaryColor = "#667eea" + h.webConfig.Theme.SecondaryColor = "#764ba2" + h.webConfig.Theme.SuccessColor = "#38b2ac" + h.webConfig.Theme.ErrorColor = "#c53030" +} + +// loadHTMLTemplate loads the HTML template +func (h *Handler) loadHTMLTemplate() { + templatePath := filepath.Join(h.templatesPath, "index.html") + + tmpl, err := template.ParseFiles(templatePath) + if err != nil { + log.Printf("Warning: Failed to load HTML template %s: %v", templatePath, err) + log.Printf("Using embedded fallback template") + h.htmlTemplate = template.Must(template.New("index").Parse(fallbackHTMLTemplate)) + } else { + h.htmlTemplate = tmpl + log.Printf("Loaded HTML template from %s", templatePath) + } +} + +// ServeStaticFile serves static files from the templates directory +func (h *Handler) ServeStaticFile(filename string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + filePath := filepath.Join(h.templatesPath, filename) + + // Check if file exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + http.NotFound(w, r) + return + } + + // Set appropriate content type + switch filepath.Ext(filename) { + case ".css": + w.Header().Set("Content-Type", "text/css") + case ".js": + w.Header().Set("Content-Type", "application/javascript") + case ".svg": + w.Header().Set("Content-Type", "image/svg+xml") + case ".png": + w.Header().Set("Content-Type", "image/png") + case ".jpg", ".jpeg": + w.Header().Set("Content-Type", "image/jpeg") + default: + // Check if it's one of our logo files without extension + if filename == "logo.png" || filename == "logo_light_background.png" || filename == "logo_dark_background.png" { + w.Header().Set("Content-Type", "image/png") + } + } + + // Enable caching for static files + w.Header().Set("Cache-Control", "public, max-age=3600") + + http.ServeFile(w, r, filePath) + } +} + +// ServeAssetFile serves asset files from the assets directory +func (h *Handler) ServeAssetFile(filename string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var filePath string + + // Try multiple possible locations for assets + possiblePaths := []string{ + // Docker container paths + "./assets/" + filename, + "/app/assets/" + filename, + "/opt/rdpgw/assets/" + filename, + // Development paths + filepath.Join("assets", filename), + } + + // Add icon.svg to the check as well + if filename == "icon.svg" { + possiblePaths = append(possiblePaths, "./icon.svg", "/app/icon.svg", "/opt/rdpgw/icon.svg") + } + + // If we have templates path, try relative to it + if h.templatesPath != "" { + templatesDir, err := filepath.Abs(h.templatesPath) + if err == nil { + // Navigate up from templates to find assets + currentDir := templatesDir + for i := 0; i < 5; i++ { + parentDir := filepath.Dir(currentDir) + if parentDir == currentDir { + break + } + possiblePaths = append(possiblePaths, filepath.Join(parentDir, "assets", filename)) + currentDir = parentDir + } + } + } + + // Test each possible path + for _, testPath := range possiblePaths { + if _, err := os.Stat(testPath); err == nil { + filePath = testPath + break + } + } + + if filePath == "" { + log.Printf("Asset file not found: %s. Tried paths: %v", filename, possiblePaths) + http.NotFound(w, r) + return + } + + // Check if file exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + http.NotFound(w, r) + return + } + + // Set appropriate content type + switch filepath.Ext(filename) { + case ".svg": + w.Header().Set("Content-Type", "image/svg+xml") + case ".png": + w.Header().Set("Content-Type", "image/png") + case ".jpg", ".jpeg": + w.Header().Set("Content-Type", "image/jpeg") + default: + // Check if it's one of our asset files without extension + if filename == "logo_light_background.png" || filename == "logo_dark_background.png" || filename == "connect.svg" { + if filepath.Ext(filename) == ".png" || filename == "logo_light_background.png" || filename == "logo_dark_background.png" { + w.Header().Set("Content-Type", "image/png") + } else { + w.Header().Set("Content-Type", "image/svg+xml") + } + } + } + + // Enable caching for asset files + w.Header().Set("Cache-Control", "public, max-age=3600") + + http.ServeFile(w, r, filePath) + } +} + +// fallbackHTMLTemplate is used when external template file is not available +const fallbackHTMLTemplate = ` + + + + + {{.Title}} + + + +
+

{{.PageTitle}}

+
+ + +
+ + +` + func (h *Handler) selectRandomHost() string { r := rnd.New(rnd.NewSource(int64(new(maphash.Hash).Sum64()))) host := h.hosts[r.Intn(len(h.hosts))] @@ -262,3 +517,107 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { // return signd rdp file http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent)) } + +// Host represents a host available for connection +type Host struct { + ID string `json:"id"` + Name string `json:"name"` + Address string `json:"address"` + Description string `json:"description"` + IsDefault bool `json:"isDefault"` +} + +// UserInfo represents the current authenticated user +type UserInfo struct { + Username string `json:"username"` + Authenticated bool `json:"authenticated"` + AuthTime time.Time `json:"authTime"` +} + +// HandleHostList returns the list of available hosts for the authenticated user +func (h *Handler) HandleHostList(w http.ResponseWriter, r *http.Request) { + id := identity.FromRequestCtx(r) + + if !id.Authenticated() { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + var hosts []Host + + // Simplified host selection - all modes work the same for the user + if h.hostSelection == "roundrobin" { + hosts = append(hosts, Host{ + ID: "roundrobin", + Name: "Available Servers", + Address: "", + Description: "Connect to an available server automatically", + IsDefault: true, + }) + } else { + // For all other modes (signed, unsigned, any), show the actual hosts + for i, hostAddr := range h.hosts { + hosts = append(hosts, Host{ + ID: fmt.Sprintf("host_%d", i), + Name: hostAddr, + Address: hostAddr, + Description: fmt.Sprintf("Connect to %s", hostAddr), + IsDefault: i == 0, + }) + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(hosts) +} + +// HandleUserInfo returns information about the current authenticated user +func (h *Handler) HandleUserInfo(w http.ResponseWriter, r *http.Request) { + id := identity.FromRequestCtx(r) + + if !id.Authenticated() { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + userInfo := UserInfo{ + Username: id.UserName(), + Authenticated: id.Authenticated(), + AuthTime: id.AuthTime(), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(userInfo) +} + +// HandleWebInterface serves the main web interface +func (h *Handler) HandleWebInterface(w http.ResponseWriter, r *http.Request) { + id := identity.FromRequestCtx(r) + + if !id.Authenticated() { + // Redirect to authentication + http.Redirect(w, r, "/connect", http.StatusFound) + return + } + + // Template data + templateData := struct { + Title string + Logo string + PageTitle string + SelectServerMessage string + PreparingMessage string + }{ + Title: h.webConfig.Branding.Title, + Logo: h.webConfig.Branding.Logo, + PageTitle: h.webConfig.Branding.PageTitle, + SelectServerMessage: h.webConfig.Messages.SelectServer, + PreparingMessage: h.webConfig.Messages.Preparing, + } + + w.Header().Set("Content-Type", "text/html") + if err := h.htmlTemplate.Execute(w, templateData); err != nil { + log.Printf("Failed to execute template: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} diff --git a/cmd/rdpgw/web/web_interface_test.go b/cmd/rdpgw/web/web_interface_test.go new file mode 100644 index 0000000..6d007e7 --- /dev/null +++ b/cmd/rdpgw/web/web_interface_test.go @@ -0,0 +1,418 @@ +package web + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" +) + +func TestHandleHostList(t *testing.T) { + tests := []struct { + name string + hostSelection string + hosts []string + authenticated bool + expectedCount int + expectedType string + }{ + { + name: "roundrobin mode", + hostSelection: "roundrobin", + hosts: []string{"host1.example.com", "host2.example.com"}, + authenticated: true, + expectedCount: 1, + expectedType: "roundrobin", + }, + { + name: "unsigned mode", + hostSelection: "unsigned", + hosts: []string{"host1.example.com", "host2.example.com", "host3.example.com"}, + authenticated: true, + expectedCount: 3, + expectedType: "individual", + }, + { + name: "any mode", + hostSelection: "any", + hosts: []string{"host1.example.com"}, + authenticated: true, + expectedCount: 1, + expectedType: "individual", + }, + { + name: "signed mode", + hostSelection: "signed", + hosts: []string{"host1.example.com", "host2.example.com"}, + authenticated: true, + expectedCount: 2, + expectedType: "signed", + }, + { + name: "unauthenticated user", + hostSelection: "roundrobin", + hosts: []string{"host1.example.com"}, + authenticated: false, + expectedCount: 0, + expectedType: "error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create handler + handler := &Handler{ + hostSelection: tt.hostSelection, + hosts: tt.hosts, + } + + // Create request + req := httptest.NewRequest("GET", "/api/v1/hosts", nil) + w := httptest.NewRecorder() + + // Set identity context + user := identity.NewUser() + if tt.authenticated { + user.SetUserName("testuser") + user.SetAuthenticated(true) + user.SetAuthTime(time.Now()) + } + req = identity.AddToRequestCtx(user, req) + + // Call handler + handler.HandleHostList(w, req) + + if !tt.authenticated { + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) + } + return + } + + // Check response + if w.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) + } + + var hosts []Host + err := json.Unmarshal(w.Body.Bytes(), &hosts) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if len(hosts) != tt.expectedCount { + t.Errorf("Expected %d hosts, got %d", tt.expectedCount, len(hosts)) + } + + if len(hosts) > 0 { + switch tt.expectedType { + case "roundrobin": + if hosts[0].ID != "roundrobin" { + t.Errorf("Expected roundrobin host, got %s", hosts[0].ID) + } + case "individual": + if !strings.Contains(hosts[0].Name, tt.hosts[0]) { + t.Errorf("Expected host name to contain %s, got %s", tt.hosts[0], hosts[0].Name) + } + case "signed": + if !strings.Contains(hosts[0].Name, tt.hosts[0]) { + t.Errorf("Expected host name to contain %s, got %s", tt.hosts[0], hosts[0].Name) + } + } + + // Check that first host is marked as default + hasDefault := false + for _, host := range hosts { + if host.IsDefault { + hasDefault = true + break + } + } + if !hasDefault { + t.Error("Expected at least one host to be marked as default") + } + } + }) + } +} + +func TestHandleUserInfo(t *testing.T) { + tests := []struct { + name string + authenticated bool + username string + authTime time.Time + }{ + { + name: "authenticated user", + authenticated: true, + username: "john.doe@example.com", + authTime: time.Now(), + }, + { + name: "unauthenticated user", + authenticated: false, + username: "", + authTime: time.Time{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create handler + handler := &Handler{} + + // Create request + req := httptest.NewRequest("GET", "/api/v1/user", nil) + w := httptest.NewRecorder() + + // Set identity context + user := identity.NewUser() + if tt.authenticated { + user.SetUserName(tt.username) + user.SetAuthenticated(true) + user.SetAuthTime(tt.authTime) + } + req = identity.AddToRequestCtx(user, req) + + // Call handler + handler.HandleUserInfo(w, req) + + if !tt.authenticated { + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) + } + return + } + + // Check response + if w.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) + } + + var userInfo UserInfo + err := json.Unmarshal(w.Body.Bytes(), &userInfo) + if err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if userInfo.Username != tt.username { + t.Errorf("Expected username %s, got %s", tt.username, userInfo.Username) + } + + if userInfo.Authenticated != tt.authenticated { + t.Errorf("Expected authenticated %v, got %v", tt.authenticated, userInfo.Authenticated) + } + + if tt.authenticated && userInfo.AuthTime.IsZero() { + t.Error("Expected non-zero auth time for authenticated user") + } + }) + } +} + +func TestHandleWebInterface(t *testing.T) { + tests := []struct { + name string + authenticated bool + expectStatus int + expectContent string + }{ + { + name: "authenticated user", + authenticated: true, + expectStatus: http.StatusOK, + expectContent: "RDP Gateway", + }, + { + name: "unauthenticated user", + authenticated: false, + expectStatus: http.StatusFound, + expectContent: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create handler with minimal configuration + handler := &Handler{ + templatesPath: "./templates", + } + handler.loadWebConfig() + handler.loadHTMLTemplate() + + // Create request + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + // Set identity context + user := identity.NewUser() + if tt.authenticated { + user.SetUserName("testuser") + user.SetAuthenticated(true) + user.SetAuthTime(time.Now()) + } + req = identity.AddToRequestCtx(user, req) + + // Call handler + handler.HandleWebInterface(w, req) + + // Check response + if w.Code != tt.expectStatus { + t.Errorf("Expected status %d, got %d", tt.expectStatus, w.Code) + } + + if tt.authenticated { + body := w.Body.String() + if !strings.Contains(body, tt.expectContent) { + t.Errorf("Expected response to contain %s", tt.expectContent) + } + + // Check that it's a complete HTML document + if !strings.Contains(body, "") { + t.Error("Expected complete HTML document") + } + + // Check for key elements (using fallback template) + expectedElements := []string{ + "serversGrid", + "connectButton", + "loadServers", + "connectToServer", + } + + for _, element := range expectedElements { + if !strings.Contains(body, element) { + t.Errorf("Expected HTML to contain %s", element) + } + } + } else { + // Check redirect location + location := w.Header().Get("Location") + if location != "/connect" { + t.Errorf("Expected redirect to /connect, got %s", location) + } + } + }) + } +} + +func TestHostSelectionIntegration(t *testing.T) { + // Test the full flow from host selection to RDP download + tests := []struct { + name string + hostSelection string + hosts []string + queryParams string + expectHost string + expectError bool + }{ + { + name: "roundrobin selection", + hostSelection: "roundrobin", + hosts: []string{"host1.com", "host2.com", "host3.com"}, + queryParams: "", + expectHost: "", // Will be one of the hosts + expectError: false, + }, + { + name: "unsigned specific host", + hostSelection: "unsigned", + hosts: []string{"host1.com", "host2.com"}, + queryParams: "?host=host2.com", + expectHost: "host2.com", + expectError: false, + }, + { + name: "unsigned invalid host", + hostSelection: "unsigned", + hosts: []string{"host1.com", "host2.com"}, + queryParams: "?host=invalid.com", + expectHost: "", + expectError: true, + }, + { + name: "any host allowed", + hostSelection: "any", + hosts: []string{"host1.com"}, + queryParams: "?host=any-host.com", + expectHost: "any-host.com", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create handler + handler := &Handler{ + hostSelection: tt.hostSelection, + hosts: tt.hosts, + gatewayAddress: &url.URL{Host: "gateway.example.com"}, + } + + // Create request for RDP download + req := httptest.NewRequest("GET", "/connect"+tt.queryParams, nil) + w := httptest.NewRecorder() + + // Set authenticated user + user := identity.NewUser() + user.SetUserName("testuser") + user.SetAuthenticated(true) + user.SetAuthTime(time.Now()) + req = identity.AddToRequestCtx(user, req) + + // Mock the token generator to avoid errors + handler.paaTokenGenerator = func(ctx context.Context, user, host string) (string, error) { + return "mock-token", nil + } + + // Call download handler + handler.HandleDownload(w, req) + + if tt.expectError { + if w.Code == http.StatusOK { + t.Error("Expected error but got success") + } + } else { + if w.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) + } + + // Check content type + contentType := w.Header().Get("Content-Type") + if contentType != "application/x-rdp" { + t.Errorf("Expected Content-Type application/x-rdp, got %s", contentType) + } + + // Check content disposition + disposition := w.Header().Get("Content-Disposition") + if !strings.Contains(disposition, "attachment") || !strings.Contains(disposition, ".rdp") { + t.Errorf("Expected attachment disposition with .rdp file, got %s", disposition) + } + + // Check RDP content for expected host + body := w.Body.String() + if tt.expectHost != "" { + if !strings.Contains(body, tt.expectHost) { + t.Errorf("Expected RDP content to contain host %s", tt.expectHost) + } + } + + // Check for gateway configuration + if !strings.Contains(body, "gateway.example.com") { + t.Error("Expected RDP content to contain gateway address") + } + + if !strings.Contains(body, "mock-token") { + t.Error("Expected RDP content to contain access token") + } + } + }) + } +} diff --git a/dev/docker/Dockerfile b/dev/docker/Dockerfile index 1631992..4bdb5cb 100644 --- a/dev/docker/Dockerfile +++ b/dev/docker/Dockerfile @@ -1,12 +1,13 @@ # builder stage FROM golang:1.24-alpine as builder + # Install CA certificates explicitly in builder -RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl ca-certificates +RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl # add user RUN adduser --disabled-password --gecos "" --home /opt/rdpgw --uid 1001 rdpgw -# certificate generation (your existing code) +# certificate generation RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \ random=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 32 | head -n 1) && \ openssl genrsa -des3 -passout pass:$random -out server.pass.key 2048 && \ @@ -16,7 +17,7 @@ RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \ -subj "/C=US/ST=VA/L=SomeCity/O=MyCompany/OU=MyDivision/CN=rdpgw" && \ openssl x509 -req -days 365 -in server.csr -signkey key.pem -out server.pem -# build rdpgw and set rights (your existing code) +# build rdpgw and set rights ARG CACHEBUST RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \ cd /app && \ @@ -29,7 +30,7 @@ RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \ FROM alpine:latest # Install CA certificates in final stage -RUN apk --no-cache add linux-pam musl tzdata ca-certificates +RUN apk --no-cache add linux-pam musl tzdata ca-certificates && update-ca-certificates # make tempdir in case filestore is used ADD tmp.tar / @@ -40,6 +41,9 @@ COPY --chown=1001 run.sh run.sh COPY --chown=1001 --from=builder /opt/rdpgw /opt/rdpgw COPY --chown=1001 --from=builder /etc/passwd /etc/passwd +# Copy assets directory from the app source +COPY --chown=1001 --from=builder /app/assets /opt/rdpgw/assets + USER 0 WORKDIR /opt/rdpgw ENTRYPOINT ["/bin/sh", "/run.sh"]