diff --git a/assets/connect.svg b/assets/connect.svg
new file mode 100644
index 0000000..3accea2
--- /dev/null
+++ b/assets/connect.svg
@@ -0,0 +1,32 @@
+
\ No newline at end of file
diff --git a/assets/icon.svg b/assets/icon.svg
new file mode 100644
index 0000000..5f07e31
--- /dev/null
+++ b/assets/icon.svg
@@ -0,0 +1,10 @@
+
diff --git a/cmd/rdpgw/main.go b/cmd/rdpgw/main.go
index 307e3c8..f5fffab 100644
--- a/cmd/rdpgw/main.go
+++ b/cmd/rdpgw/main.go
@@ -213,6 +213,9 @@ func main() {
// for sso callbacks
r.HandleFunc("/tokeninfo", web.TokenInfo)
+ // API routes
+ api := r.PathPrefix("/api/v1").Subrouter()
+
// gateway endpoint
rdp := r.PathPrefix(gatewayEndPoint).Subrouter()
@@ -223,6 +226,18 @@ func main() {
r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload)))
r.HandleFunc("/callback", o.HandleCallback)
+ // Web interface and API routes (authenticated)
+ r.Handle("/", o.Authenticated(http.HandlerFunc(h.HandleWebInterface)))
+ api.Handle("/hosts", o.Authenticated(http.HandlerFunc(h.HandleHostList)))
+ api.Handle("/user", o.Authenticated(http.HandlerFunc(h.HandleUserInfo)))
+
+ // Static files (no authentication required)
+ r.HandleFunc("/static/style.css", h.ServeStaticFile("style.css"))
+ r.HandleFunc("/static/app.js", h.ServeStaticFile("app.js"))
+ // Asset files (no authentication required)
+ r.HandleFunc("/assets/connect.svg", h.ServeAssetFile("connect.svg"))
+ r.HandleFunc("/assets/icon.svg", h.ServeAssetFile("icon.svg"))
+
// only enable un-auth endpoint for openid only config
if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() && !conf.Server.HeaderEnabled() {
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)
@@ -241,6 +256,18 @@ func main() {
headerAuth := headerConfig.New()
r.Handle("/connect", headerAuth.Authenticated(http.HandlerFunc(h.HandleDownload)))
+ // Web interface and API routes (authenticated)
+ r.Handle("/", headerAuth.Authenticated(http.HandlerFunc(h.HandleWebInterface)))
+ api.Handle("/hosts", headerAuth.Authenticated(http.HandlerFunc(h.HandleHostList)))
+ api.Handle("/user", headerAuth.Authenticated(http.HandlerFunc(h.HandleUserInfo)))
+
+ // Static files (no authentication required)
+ r.HandleFunc("/static/style.css", h.ServeStaticFile("style.css"))
+ r.HandleFunc("/static/app.js", h.ServeStaticFile("app.js"))
+ // Asset files (no authentication required)
+ r.HandleFunc("/assets/connect.svg", h.ServeAssetFile("connect.svg"))
+ r.HandleFunc("/assets/icon.svg", h.ServeAssetFile("icon.svg"))
+
// only enable un-auth endpoint for header only config
if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() && !conf.Server.OpenIDEnabled() {
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)
diff --git a/cmd/rdpgw/templates/README.md b/cmd/rdpgw/templates/README.md
new file mode 100644
index 0000000..f60acf8
--- /dev/null
+++ b/cmd/rdpgw/templates/README.md
@@ -0,0 +1,95 @@
+# RDP Gateway Web Interface Templates
+
+This directory contains the customizable web interface templates for RDP Gateway.
+
+## Files
+
+### `index.html`
+The main HTML template for the web interface. This file uses Go template syntax and can be customized to match your organization's branding.
+
+**Template Variables Available:**
+- `{{.Title}}` - Page title
+- `{{.Logo}}` - Header logo text
+- `{{.PageTitle}}` - Main page heading
+- `{{.SelectServerMessage}}` - Default button text
+- `{{.PreparingMessage}}` - Loading message
+- `{{.AutoLaunchMessage}}` - Auto-launch notice text
+
+### `style.css`
+The CSS stylesheet for the web interface. Modify this file to customize:
+- Colors and branding
+- Layout and spacing
+- Fonts and typography
+- Responsive behavior
+
+### `app.js`
+The JavaScript file containing the web interface logic. This includes:
+- Server list loading and rendering
+- User authentication display
+- **Automatic RDP client launching** (multiple methods)
+- File download fallback
+- Progress animations
+
+### `config-example.json`
+Example configuration structure showing available customization options. These values are set as defaults in the code but can be integrated with your main configuration system.
+
+## Auto-Launch Functionality
+
+The interface automatically attempts to launch RDP clients using **actual RDP file content**:
+
+### How It Works:
+1. **Fetches RDP Content**: Gets the complete RDP file configuration from `/api/rdp-content`
+2. **Creates Data URL**: Converts RDP content to a downloadable blob
+3. **Platform-Specific Launch**:
+ - **Windows**: Downloads .rdp file which auto-opens with mstsc
+ - **macOS**: Downloads .rdp file which auto-opens with Microsoft Remote Desktop
+ - **Universal**: Creates temporary download that browsers handle appropriately
+
+### Technical Implementation:
+- **`/api/rdp-content`** endpoint generates actual RDP file content with proper tokens
+- **Data URLs** created from RDP content for browser download
+- **Automatic file association** triggers RDP client launch
+- **Graceful fallbacks** ensure users always get the RDP file
+
+## Customization
+
+To customize the interface:
+
+1. **Copy this templates directory** to your preferred location
+2. **Set the templates path** in your RDP Gateway configuration
+3. **Edit the files** to match your branding requirements
+4. **Restart RDP Gateway** to load the new templates
+
+If template files are missing, the system automatically falls back to embedded templates to ensure the interface remains functional.
+
+## API Endpoints
+
+The web interface uses these authenticated API endpoints:
+
+- **`/api/hosts`** - Returns available servers for the user (JSON)
+- **`/api/user`** - Returns current user information (JSON)
+- **`/api/rdp-content`** - Returns RDP file content as text for auto-launch
+- **`/connect`** - Downloads RDP file (traditional endpoint)
+
+## Static File Serving
+
+The following URLs serve static files:
+- `/static/style.css` - CSS stylesheet
+- `/static/app.js` - JavaScript application
+
+These files are served without authentication requirements for better performance.
+
+## Browser Compatibility
+
+The interface supports:
+- Modern browsers (Chrome, Firefox, Safari, Edge)
+- Mobile responsive design
+- Protocol handlers for RDP client launching
+- Graceful fallbacks for unsupported features
+
+## Security Considerations
+
+- Template files are served from the server filesystem
+- Static files include cache headers for performance
+- User authentication is required for the main interface
+- API endpoints validate authentication before serving data
\ No newline at end of file
diff --git a/cmd/rdpgw/templates/app.js b/cmd/rdpgw/templates/app.js
new file mode 100644
index 0000000..e4ea81c
--- /dev/null
+++ b/cmd/rdpgw/templates/app.js
@@ -0,0 +1,223 @@
+// RDP Gateway Web Interface
+let userInfo = null;
+
+// Theme handling - SVG logo works for both light and dark modes
+function updateLogo() {
+ const logoImage = document.getElementById('logoImage');
+ if (logoImage) {
+ logoImage.src = '/assets/icon.svg';
+ }
+}
+
+// Configuration
+const config = {
+ progressAnimationDuration: 2000, // ms for progress bar animation
+};
+
+// Get user initials for avatar
+function getUserInitials(name) {
+ return name.split(' ').map(word => word.charAt(0)).slice(0, 2).join('').toUpperCase() || 'U';
+}
+
+// Load user information
+async function loadUserInfo() {
+ try {
+ const response = await fetch('/api/v1/user');
+ if (response.ok) {
+ userInfo = await response.json();
+ document.getElementById('username').textContent = userInfo.username;
+ document.getElementById('userAvatar').textContent = getUserInitials(userInfo.username);
+ } else {
+ throw new Error('Failed to load user info');
+ }
+ } catch (error) {
+ showError('Failed to load user information');
+ }
+}
+
+// Load available servers
+async function loadServers() {
+ try {
+ const response = await fetch('/api/v1/hosts');
+ if (response.ok) {
+ const servers = await response.json();
+ renderServers(servers);
+ } else {
+ throw new Error('Failed to load servers');
+ }
+ } catch (error) {
+ showError('Failed to load available servers');
+ }
+}
+
+// Render servers in the grid
+function renderServers(servers) {
+ const grid = document.getElementById('serversGrid');
+ grid.innerHTML = '';
+
+ servers.forEach(server => {
+ const card = document.createElement('div');
+ card.className = 'server-card';
+
+ const connectButton = document.createElement('button');
+ connectButton.className = 'server-connect-button';
+ connectButton.textContent = `Connect to ${server.name}`;
+ connectButton.onclick = (e) => {
+ e.stopPropagation();
+ connectToServer(server, connectButton);
+ };
+
+ card.innerHTML = `
+
+
+

+
+
+
${server.name}
+
${server.description}
+
+
+ `;
+
+ card.appendChild(connectButton);
+ grid.appendChild(card);
+ });
+}
+
+
+// Show error message
+function showError(message) {
+ const errorDiv = document.getElementById('error');
+ errorDiv.textContent = message;
+ errorDiv.style.display = 'block';
+ hideSuccess();
+}
+
+// Hide error message
+function hideError() {
+ document.getElementById('error').style.display = 'none';
+}
+
+// Show success message
+function showSuccess(message) {
+ const successDiv = document.getElementById('success');
+ successDiv.textContent = message;
+ successDiv.style.display = 'block';
+ hideError();
+}
+
+// Hide success message
+function hideSuccess() {
+ document.getElementById('success').style.display = 'none';
+}
+
+// Animate progress bar
+function animateProgress(duration = config.progressAnimationDuration) {
+ const progressFill = document.getElementById('progressFill');
+ progressFill.style.width = '0%';
+
+ let startTime = null;
+ function animate(currentTime) {
+ if (!startTime) startTime = currentTime;
+ const elapsed = currentTime - startTime;
+ const progress = Math.min(elapsed / duration, 1);
+
+ progressFill.style.width = (progress * 100) + '%';
+
+ if (progress < 1) {
+ requestAnimationFrame(animate);
+ }
+ }
+ requestAnimationFrame(animate);
+}
+
+// Generate filename with user initials and random prefix
+function generateFilename() {
+ if (!userInfo) return 'connection.rdp';
+
+ const initials = getUserInitials(userInfo.username);
+ const randomStr = Math.random().toString(36).substring(2, 8).toUpperCase();
+ return `${initials}_${randomStr}.rdp`;
+}
+
+// Download RDP file
+async function downloadRDPFile(url) {
+ const link = document.createElement('a');
+ link.href = url;
+ link.download = generateFilename();
+ link.style.display = 'none';
+ document.body.appendChild(link);
+ link.click();
+ document.body.removeChild(link);
+}
+
+// Connect to server
+async function connectToServer(server, button) {
+ if (!server) return;
+
+ hideError();
+ hideSuccess();
+
+ const originalButtonText = button.textContent;
+ const loading = document.getElementById('loading');
+
+ // Update UI for loading state
+ button.disabled = true;
+ button.textContent = 'Downloading...';
+ loading.style.display = 'block';
+
+ // Start progress animation
+ animateProgress();
+
+ try {
+ // Build the RDP download URL
+ let url = '/connect';
+ if (server.address) {
+ url += '?host=' + encodeURIComponent(server.address);
+ }
+
+ // Wait a moment for better UX
+ await new Promise(resolve => setTimeout(resolve, 500));
+
+ // Download the RDP file
+ await downloadRDPFile(url);
+ showSuccess('RDP file downloaded. Please open it with your RDP client.');
+
+ // Reset UI after a delay
+ setTimeout(() => {
+ button.disabled = false;
+ button.textContent = originalButtonText;
+ loading.style.display = 'none';
+ }, 2000);
+
+ } catch (error) {
+ console.error('Connection error:', error);
+ showError('Failed to download RDP file. Please try again.');
+
+ // Reset UI immediately on error
+ button.disabled = false;
+ button.textContent = originalButtonText;
+ loading.style.display = 'none';
+ }
+}
+
+
+// Initialize the application
+document.addEventListener('DOMContentLoaded', async () => {
+ // Set initial logo based on theme
+ updateLogo();
+
+ // Load data
+ await loadUserInfo();
+ await loadServers();
+
+ // No additional event handlers needed - buttons are handled in renderServers
+});
+
+// Handle visibility change (for auto-refresh when tab becomes visible)
+document.addEventListener('visibilitychange', () => {
+ if (!document.hidden) {
+ // Refresh server list when tab becomes visible
+ loadServers();
+ }
+});
\ No newline at end of file
diff --git a/cmd/rdpgw/templates/config-example.json b/cmd/rdpgw/templates/config-example.json
new file mode 100644
index 0000000..7224cbb
--- /dev/null
+++ b/cmd/rdpgw/templates/config-example.json
@@ -0,0 +1,22 @@
+{
+ "branding": {
+ "title": "RDP Gateway",
+ "logo": "RDP Gateway",
+ "page_title": "Select a Server to Connect"
+ },
+ "messages": {
+ "select_server": "Select a server to connect",
+ "preparing": "Preparing your connection..."
+ },
+ "ui": {
+ "progress_animation_duration_ms": 2000,
+ "auto_select_default": true,
+ "show_user_avatar": true
+ },
+ "theme": {
+ "primary_color": "#667eea",
+ "secondary_color": "#764ba2",
+ "success_color": "#38b2ac",
+ "error_color": "#c53030"
+ }
+}
\ No newline at end of file
diff --git a/cmd/rdpgw/templates/index.html b/cmd/rdpgw/templates/index.html
new file mode 100644
index 0000000..fccbb61
--- /dev/null
+++ b/cmd/rdpgw/templates/index.html
@@ -0,0 +1,45 @@
+
+
+
+
+
+ {{.Title}}
+
+
+
+
+
+
+
+
+
+
{{.PageTitle}}
+
+
+
+
+
+
+
+
+
+ {{.PreparingMessage}}
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/cmd/rdpgw/templates/style.css b/cmd/rdpgw/templates/style.css
new file mode 100644
index 0000000..78f696b
--- /dev/null
+++ b/cmd/rdpgw/templates/style.css
@@ -0,0 +1,305 @@
+1:root {
+ /* Light mode colors (OKLCH) */
+ --background: oklch(1 0 0);
+ --foreground: oklch(0.145 0 0);
+ --primary: oklch(0.205 0 0);
+ --secondary: oklch(0.97 0 0);
+ --accent: oklch(0.97 0 0);
+ --destructive: oklch(0.577 0.245 27.325);
+ --border: oklch(0.922 0 0);
+ --muted: oklch(0.97 0 0);
+ --muted-foreground: oklch(0.466 0 0);
+ --card: oklch(1 0 0);
+ --popover: oklch(1 0 0);
+
+ /* Border radius - Pocket ID system */
+ --radius: 0.75rem;
+ --radius-sm: calc(var(--radius) - 4px);
+ --radius-md: calc(var(--radius) - 2px);
+ --radius-lg: var(--radius);
+ --radius-xl: calc(var(--radius) + 4px);
+}
+
+@media (prefers-color-scheme: dark) {
+ :root {
+ /* Dark mode colors */
+ --background: oklch(0.145 0 0);
+ --foreground: oklch(0.985 0 0);
+ --primary: oklch(0.922 0 0);
+ --secondary: oklch(0.269 0 0);
+ --accent: oklch(0.269 0 0);
+ --destructive: oklch(0.704 0.191 22.216);
+ --border: oklch(0.269 0 0);
+ --muted: oklch(0.205 0 0);
+ --muted-foreground: oklch(0.722 0 0);
+ --card: oklch(0.145 0 0);
+ --popover: oklch(0.145 0 0);
+ }
+}
+
+* {
+ margin: 0;
+ padding: 0;
+ box-sizing: border-box;
+}
+
+body {
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
+ background: var(--background);
+ color: var(--foreground);
+ min-height: 100vh;
+ display: flex;
+ flex-direction: column;
+ transition: color 0.2s, background-color 0.2s;
+}
+
+.header {
+ background: var(--card);
+ border-bottom: 1px solid var(--border);
+ padding: 1rem 2rem;
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05);
+}
+
+.logo {
+ display: flex;
+ align-items: center;
+ gap: 0.75rem;
+ font-size: 1.25rem;
+ font-weight: 600;
+ color: var(--foreground);
+}
+
+.logo img {
+ height: 2rem;
+ width: auto;
+}
+
+.user-info {
+ display: flex;
+ align-items: center;
+ gap: 1rem;
+ color: var(--foreground);
+}
+
+.user-avatar {
+ width: 40px;
+ height: 40px;
+ border-radius: 50%;
+ background: var(--secondary);
+ border: 1px solid var(--border);
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ font-weight: 600;
+ color: var(--foreground);
+}
+
+.main {
+ flex: 1;
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ padding: 2rem;
+}
+
+.container {
+ background: var(--card);
+ border-radius: var(--radius-xl);
+ border: 1px solid var(--border);
+ padding: 2rem;
+ box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
+ max-width: 800px;
+ width: 100%;
+}
+
+.title {
+ text-align: center;
+ margin-bottom: 2rem;
+ color: var(--foreground);
+ font-weight: 600;
+}
+
+.servers-grid {
+ display: grid;
+ grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
+ gap: 1rem;
+ margin-bottom: 2rem;
+}
+
+.server-card {
+ border: 1px solid var(--border);
+ border-radius: var(--radius-lg);
+ padding: 0;
+ transition: all 0.2s;
+ position: relative;
+ background: var(--card);
+ display: flex;
+ flex-direction: column;
+ overflow: hidden;
+}
+
+.server-card:hover {
+ transform: translateY(-1px);
+ box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
+}
+
+
+.server-content {
+ padding: 1.5rem;
+ flex: 1;
+ display: flex;
+ gap: 1rem;
+ align-items: flex-start;
+}
+
+.server-icon {
+ flex-shrink: 0;
+ line-height: 1;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+}
+
+.server-icon img {
+ width: 3rem;
+ height: 3rem;
+ object-fit: contain;
+}
+
+.server-info {
+ flex: 1;
+ min-width: 0;
+}
+
+.server-name {
+ font-size: 1.1rem;
+ font-weight: 600;
+ margin-bottom: 0.5rem;
+ color: var(--foreground);
+ word-wrap: break-word;
+}
+
+.server-description {
+ color: var(--muted-foreground);
+ font-size: 0.9rem;
+ line-height: 1.4;
+}
+
+.server-connect-button {
+ width: 100%;
+ background: var(--primary);
+ color: var(--background);
+ border: none;
+ border-top: 1px solid var(--border);
+ border-radius: 0 0 var(--radius-lg) var(--radius-lg);
+ padding: 1rem 1.5rem;
+ font-size: 1rem;
+ font-weight: 600;
+ cursor: pointer;
+ transition: all 0.2s;
+ margin: 0;
+}
+
+.server-connect-button:hover:not(:disabled) {
+ background: var(--primary);
+ opacity: 0.9;
+}
+
+.server-connect-button:disabled {
+ background: var(--muted);
+ color: var(--muted-foreground);
+ opacity: 0.5;
+ cursor: not-allowed;
+}
+
+.loading {
+ display: none;
+ text-align: center;
+ margin-top: 1rem;
+ color: var(--foreground);
+ opacity: 0.7;
+}
+
+.error {
+ background: color-mix(in oklch, var(--destructive) 10%, transparent);
+ color: var(--destructive);
+ border: 1px solid color-mix(in oklch, var(--destructive) 20%, transparent);
+ padding: 1rem;
+ border-radius: var(--radius-md);
+ margin-bottom: 1rem;
+ display: none;
+}
+
+.success {
+ background: color-mix(in oklch, var(--primary) 10%, transparent);
+ color: var(--primary);
+ border: 1px solid color-mix(in oklch, var(--primary) 20%, transparent);
+ padding: 1rem;
+ border-radius: var(--radius-md);
+ margin-bottom: 1rem;
+ display: none;
+ text-align: center;
+}
+
+
+.progress-bar {
+ width: 100%;
+ height: 4px;
+ background: var(--border);
+ border-radius: var(--radius-sm);
+ overflow: hidden;
+ margin-top: 0.5rem;
+}
+
+.progress-fill {
+ height: 100%;
+ background: var(--primary);
+ width: 0%;
+ transition: width 0.3s ease;
+}
+
+/* Responsive design */
+@media (max-width: 768px) {
+ .header {
+ padding: 1rem;
+ }
+
+ .main {
+ padding: 1rem;
+ }
+
+ .container {
+ padding: 1.5rem;
+ }
+
+ .servers-grid {
+ grid-template-columns: 1fr;
+ }
+
+ .logo {
+ font-size: 1.2rem;
+ }
+}
+
+/* Custom scrollbar */
+::-webkit-scrollbar {
+ width: 8px;
+}
+
+::-webkit-scrollbar-track {
+ background: rgba(255, 255, 255, 0.1);
+ border-radius: 4px;
+}
+
+::-webkit-scrollbar-thumb {
+ background: rgba(255, 255, 255, 0.3);
+ border-radius: 4px;
+}
+
+::-webkit-scrollbar-thumb:hover {
+ background: rgba(255, 255, 255, 0.5);
+}
\ No newline at end of file
diff --git a/cmd/rdpgw/web/web.go b/cmd/rdpgw/web/web.go
index 42bc036..9640680 100644
--- a/cmd/rdpgw/web/web.go
+++ b/cmd/rdpgw/web/web.go
@@ -5,13 +5,17 @@ import (
"context"
"crypto/rand"
"encoding/hex"
+ "encoding/json"
"errors"
"fmt"
"hash/maphash"
+ "html/template"
"log"
rnd "math/rand"
"net/http"
"net/url"
+ "os"
+ "path/filepath"
"strings"
"time"
@@ -37,6 +41,31 @@ type Config struct {
TemplateFile string
RdpSigningCert string
RdpSigningKey string
+ TemplatesPath string
+}
+
+// WebConfig represents the web interface configuration
+type WebConfig struct {
+ Branding struct {
+ Title string `json:"title"`
+ Logo string `json:"logo"`
+ PageTitle string `json:"page_title"`
+ } `json:"branding"`
+ Messages struct {
+ SelectServer string `json:"select_server"`
+ Preparing string `json:"preparing"`
+ } `json:"messages"`
+ UI struct {
+ ProgressAnimationDurationMs int `json:"progress_animation_duration_ms"`
+ AutoSelectDefault bool `json:"auto_select_default"`
+ ShowUserAvatar bool `json:"show_user_avatar"`
+ } `json:"ui"`
+ Theme struct {
+ PrimaryColor string `json:"primary_color"`
+ SecondaryColor string `json:"secondary_color"`
+ SuccessColor string `json:"success_color"`
+ ErrorColor string `json:"error_color"`
+ } `json:"theme"`
}
type RdpOpts struct {
@@ -57,6 +86,9 @@ type Handler struct {
rdpOpts RdpOpts
rdpDefaults string
rdpSigner *rdpsign.Signer
+ templatesPath string
+ webConfig *WebConfig
+ htmlTemplate *template.Template
}
func (c *Config) NewHandler() *Handler {
@@ -75,6 +107,7 @@ func (c *Config) NewHandler() *Handler {
hostSelection: c.HostSelection,
rdpOpts: c.RdpOpts,
rdpDefaults: c.TemplateFile,
+ templatesPath: c.TemplatesPath,
}
// set up RDP signer if config values are set
@@ -87,9 +120,231 @@ func (c *Config) NewHandler() *Handler {
handler.rdpSigner = signer
}
+ // Set up templates path
+ if handler.templatesPath == "" {
+ handler.templatesPath = "./templates"
+ }
+
+ // Load web configuration
+ handler.loadWebConfig()
+
+ // Load HTML template
+ handler.loadHTMLTemplate()
+
return handler
}
+// loadWebConfig sets up the web interface configuration with defaults
+func (h *Handler) loadWebConfig() {
+ // Set defaults - these can be overridden by the main config system later
+ h.webConfig = &WebConfig{}
+ h.webConfig.Branding.Title = "RDP Gateway"
+ h.webConfig.Branding.Logo = "RDP Gateway"
+ h.webConfig.Branding.PageTitle = "Select a Server to Connect"
+ h.webConfig.Messages.SelectServer = "Select a server to connect"
+ h.webConfig.Messages.Preparing = "Preparing your connection..."
+ h.webConfig.UI.ProgressAnimationDurationMs = 2000
+ h.webConfig.UI.AutoSelectDefault = true
+ h.webConfig.UI.ShowUserAvatar = true
+ h.webConfig.Theme.PrimaryColor = "#667eea"
+ h.webConfig.Theme.SecondaryColor = "#764ba2"
+ h.webConfig.Theme.SuccessColor = "#38b2ac"
+ h.webConfig.Theme.ErrorColor = "#c53030"
+}
+
+// loadHTMLTemplate loads the HTML template
+func (h *Handler) loadHTMLTemplate() {
+ templatePath := filepath.Join(h.templatesPath, "index.html")
+
+ tmpl, err := template.ParseFiles(templatePath)
+ if err != nil {
+ log.Printf("Warning: Failed to load HTML template %s: %v", templatePath, err)
+ log.Printf("Using embedded fallback template")
+ h.htmlTemplate = template.Must(template.New("index").Parse(fallbackHTMLTemplate))
+ } else {
+ h.htmlTemplate = tmpl
+ log.Printf("Loaded HTML template from %s", templatePath)
+ }
+}
+
+// ServeStaticFile serves static files from the templates directory
+func (h *Handler) ServeStaticFile(filename string) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ filePath := filepath.Join(h.templatesPath, filename)
+
+ // Check if file exists
+ if _, err := os.Stat(filePath); os.IsNotExist(err) {
+ http.NotFound(w, r)
+ return
+ }
+
+ // Set appropriate content type
+ switch filepath.Ext(filename) {
+ case ".css":
+ w.Header().Set("Content-Type", "text/css")
+ case ".js":
+ w.Header().Set("Content-Type", "application/javascript")
+ case ".svg":
+ w.Header().Set("Content-Type", "image/svg+xml")
+ case ".png":
+ w.Header().Set("Content-Type", "image/png")
+ case ".jpg", ".jpeg":
+ w.Header().Set("Content-Type", "image/jpeg")
+ default:
+ // Check if it's one of our logo files without extension
+ if filename == "logo.png" || filename == "logo_light_background.png" || filename == "logo_dark_background.png" {
+ w.Header().Set("Content-Type", "image/png")
+ }
+ }
+
+ // Enable caching for static files
+ w.Header().Set("Cache-Control", "public, max-age=3600")
+
+ http.ServeFile(w, r, filePath)
+ }
+}
+
+// ServeAssetFile serves asset files from the assets directory
+func (h *Handler) ServeAssetFile(filename string) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ var filePath string
+
+ // Try multiple possible locations for assets
+ possiblePaths := []string{
+ // Docker container paths
+ "./assets/" + filename,
+ "/app/assets/" + filename,
+ "/opt/rdpgw/assets/" + filename,
+ // Development paths
+ filepath.Join("assets", filename),
+ }
+
+ // Add icon.svg to the check as well
+ if filename == "icon.svg" {
+ possiblePaths = append(possiblePaths, "./icon.svg", "/app/icon.svg", "/opt/rdpgw/icon.svg")
+ }
+
+ // If we have templates path, try relative to it
+ if h.templatesPath != "" {
+ templatesDir, err := filepath.Abs(h.templatesPath)
+ if err == nil {
+ // Navigate up from templates to find assets
+ currentDir := templatesDir
+ for i := 0; i < 5; i++ {
+ parentDir := filepath.Dir(currentDir)
+ if parentDir == currentDir {
+ break
+ }
+ possiblePaths = append(possiblePaths, filepath.Join(parentDir, "assets", filename))
+ currentDir = parentDir
+ }
+ }
+ }
+
+ // Test each possible path
+ for _, testPath := range possiblePaths {
+ if _, err := os.Stat(testPath); err == nil {
+ filePath = testPath
+ break
+ }
+ }
+
+ if filePath == "" {
+ log.Printf("Asset file not found: %s. Tried paths: %v", filename, possiblePaths)
+ http.NotFound(w, r)
+ return
+ }
+
+ // Check if file exists
+ if _, err := os.Stat(filePath); os.IsNotExist(err) {
+ http.NotFound(w, r)
+ return
+ }
+
+ // Set appropriate content type
+ switch filepath.Ext(filename) {
+ case ".svg":
+ w.Header().Set("Content-Type", "image/svg+xml")
+ case ".png":
+ w.Header().Set("Content-Type", "image/png")
+ case ".jpg", ".jpeg":
+ w.Header().Set("Content-Type", "image/jpeg")
+ default:
+ // Check if it's one of our asset files without extension
+ if filename == "logo_light_background.png" || filename == "logo_dark_background.png" || filename == "connect.svg" {
+ if filepath.Ext(filename) == ".png" || filename == "logo_light_background.png" || filename == "logo_dark_background.png" {
+ w.Header().Set("Content-Type", "image/png")
+ } else {
+ w.Header().Set("Content-Type", "image/svg+xml")
+ }
+ }
+ }
+
+ // Enable caching for asset files
+ w.Header().Set("Cache-Control", "public, max-age=3600")
+
+ http.ServeFile(w, r, filePath)
+ }
+}
+
+// fallbackHTMLTemplate is used when external template file is not available
+const fallbackHTMLTemplate = `
+
+
+
+
+ {{.Title}}
+
+
+
+
+
{{.PageTitle}}
+
+
+
{{.PreparingMessage}}
+
+
+
+`
+
func (h *Handler) selectRandomHost() string {
r := rnd.New(rnd.NewSource(int64(new(maphash.Hash).Sum64())))
host := h.hosts[r.Intn(len(h.hosts))]
@@ -262,3 +517,107 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
// return signd rdp file
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
}
+
+// Host represents a host available for connection
+type Host struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Address string `json:"address"`
+ Description string `json:"description"`
+ IsDefault bool `json:"isDefault"`
+}
+
+// UserInfo represents the current authenticated user
+type UserInfo struct {
+ Username string `json:"username"`
+ Authenticated bool `json:"authenticated"`
+ AuthTime time.Time `json:"authTime"`
+}
+
+// HandleHostList returns the list of available hosts for the authenticated user
+func (h *Handler) HandleHostList(w http.ResponseWriter, r *http.Request) {
+ id := identity.FromRequestCtx(r)
+
+ if !id.Authenticated() {
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ return
+ }
+
+ var hosts []Host
+
+ // Simplified host selection - all modes work the same for the user
+ if h.hostSelection == "roundrobin" {
+ hosts = append(hosts, Host{
+ ID: "roundrobin",
+ Name: "Available Servers",
+ Address: "",
+ Description: "Connect to an available server automatically",
+ IsDefault: true,
+ })
+ } else {
+ // For all other modes (signed, unsigned, any), show the actual hosts
+ for i, hostAddr := range h.hosts {
+ hosts = append(hosts, Host{
+ ID: fmt.Sprintf("host_%d", i),
+ Name: hostAddr,
+ Address: hostAddr,
+ Description: fmt.Sprintf("Connect to %s", hostAddr),
+ IsDefault: i == 0,
+ })
+ }
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(hosts)
+}
+
+// HandleUserInfo returns information about the current authenticated user
+func (h *Handler) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
+ id := identity.FromRequestCtx(r)
+
+ if !id.Authenticated() {
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ return
+ }
+
+ userInfo := UserInfo{
+ Username: id.UserName(),
+ Authenticated: id.Authenticated(),
+ AuthTime: id.AuthTime(),
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(userInfo)
+}
+
+// HandleWebInterface serves the main web interface
+func (h *Handler) HandleWebInterface(w http.ResponseWriter, r *http.Request) {
+ id := identity.FromRequestCtx(r)
+
+ if !id.Authenticated() {
+ // Redirect to authentication
+ http.Redirect(w, r, "/connect", http.StatusFound)
+ return
+ }
+
+ // Template data
+ templateData := struct {
+ Title string
+ Logo string
+ PageTitle string
+ SelectServerMessage string
+ PreparingMessage string
+ }{
+ Title: h.webConfig.Branding.Title,
+ Logo: h.webConfig.Branding.Logo,
+ PageTitle: h.webConfig.Branding.PageTitle,
+ SelectServerMessage: h.webConfig.Messages.SelectServer,
+ PreparingMessage: h.webConfig.Messages.Preparing,
+ }
+
+ w.Header().Set("Content-Type", "text/html")
+ if err := h.htmlTemplate.Execute(w, templateData); err != nil {
+ log.Printf("Failed to execute template: %v", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ }
+}
diff --git a/cmd/rdpgw/web/web_interface_test.go b/cmd/rdpgw/web/web_interface_test.go
new file mode 100644
index 0000000..6d007e7
--- /dev/null
+++ b/cmd/rdpgw/web/web_interface_test.go
@@ -0,0 +1,418 @@
+package web
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
+)
+
+func TestHandleHostList(t *testing.T) {
+ tests := []struct {
+ name string
+ hostSelection string
+ hosts []string
+ authenticated bool
+ expectedCount int
+ expectedType string
+ }{
+ {
+ name: "roundrobin mode",
+ hostSelection: "roundrobin",
+ hosts: []string{"host1.example.com", "host2.example.com"},
+ authenticated: true,
+ expectedCount: 1,
+ expectedType: "roundrobin",
+ },
+ {
+ name: "unsigned mode",
+ hostSelection: "unsigned",
+ hosts: []string{"host1.example.com", "host2.example.com", "host3.example.com"},
+ authenticated: true,
+ expectedCount: 3,
+ expectedType: "individual",
+ },
+ {
+ name: "any mode",
+ hostSelection: "any",
+ hosts: []string{"host1.example.com"},
+ authenticated: true,
+ expectedCount: 1,
+ expectedType: "individual",
+ },
+ {
+ name: "signed mode",
+ hostSelection: "signed",
+ hosts: []string{"host1.example.com", "host2.example.com"},
+ authenticated: true,
+ expectedCount: 2,
+ expectedType: "signed",
+ },
+ {
+ name: "unauthenticated user",
+ hostSelection: "roundrobin",
+ hosts: []string{"host1.example.com"},
+ authenticated: false,
+ expectedCount: 0,
+ expectedType: "error",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create handler
+ handler := &Handler{
+ hostSelection: tt.hostSelection,
+ hosts: tt.hosts,
+ }
+
+ // Create request
+ req := httptest.NewRequest("GET", "/api/v1/hosts", nil)
+ w := httptest.NewRecorder()
+
+ // Set identity context
+ user := identity.NewUser()
+ if tt.authenticated {
+ user.SetUserName("testuser")
+ user.SetAuthenticated(true)
+ user.SetAuthTime(time.Now())
+ }
+ req = identity.AddToRequestCtx(user, req)
+
+ // Call handler
+ handler.HandleHostList(w, req)
+
+ if !tt.authenticated {
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
+ }
+ return
+ }
+
+ // Check response
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
+ }
+
+ var hosts []Host
+ err := json.Unmarshal(w.Body.Bytes(), &hosts)
+ if err != nil {
+ t.Fatalf("Failed to unmarshal response: %v", err)
+ }
+
+ if len(hosts) != tt.expectedCount {
+ t.Errorf("Expected %d hosts, got %d", tt.expectedCount, len(hosts))
+ }
+
+ if len(hosts) > 0 {
+ switch tt.expectedType {
+ case "roundrobin":
+ if hosts[0].ID != "roundrobin" {
+ t.Errorf("Expected roundrobin host, got %s", hosts[0].ID)
+ }
+ case "individual":
+ if !strings.Contains(hosts[0].Name, tt.hosts[0]) {
+ t.Errorf("Expected host name to contain %s, got %s", tt.hosts[0], hosts[0].Name)
+ }
+ case "signed":
+ if !strings.Contains(hosts[0].Name, tt.hosts[0]) {
+ t.Errorf("Expected host name to contain %s, got %s", tt.hosts[0], hosts[0].Name)
+ }
+ }
+
+ // Check that first host is marked as default
+ hasDefault := false
+ for _, host := range hosts {
+ if host.IsDefault {
+ hasDefault = true
+ break
+ }
+ }
+ if !hasDefault {
+ t.Error("Expected at least one host to be marked as default")
+ }
+ }
+ })
+ }
+}
+
+func TestHandleUserInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ authenticated bool
+ username string
+ authTime time.Time
+ }{
+ {
+ name: "authenticated user",
+ authenticated: true,
+ username: "john.doe@example.com",
+ authTime: time.Now(),
+ },
+ {
+ name: "unauthenticated user",
+ authenticated: false,
+ username: "",
+ authTime: time.Time{},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create handler
+ handler := &Handler{}
+
+ // Create request
+ req := httptest.NewRequest("GET", "/api/v1/user", nil)
+ w := httptest.NewRecorder()
+
+ // Set identity context
+ user := identity.NewUser()
+ if tt.authenticated {
+ user.SetUserName(tt.username)
+ user.SetAuthenticated(true)
+ user.SetAuthTime(tt.authTime)
+ }
+ req = identity.AddToRequestCtx(user, req)
+
+ // Call handler
+ handler.HandleUserInfo(w, req)
+
+ if !tt.authenticated {
+ if w.Code != http.StatusUnauthorized {
+ t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
+ }
+ return
+ }
+
+ // Check response
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
+ }
+
+ var userInfo UserInfo
+ err := json.Unmarshal(w.Body.Bytes(), &userInfo)
+ if err != nil {
+ t.Fatalf("Failed to unmarshal response: %v", err)
+ }
+
+ if userInfo.Username != tt.username {
+ t.Errorf("Expected username %s, got %s", tt.username, userInfo.Username)
+ }
+
+ if userInfo.Authenticated != tt.authenticated {
+ t.Errorf("Expected authenticated %v, got %v", tt.authenticated, userInfo.Authenticated)
+ }
+
+ if tt.authenticated && userInfo.AuthTime.IsZero() {
+ t.Error("Expected non-zero auth time for authenticated user")
+ }
+ })
+ }
+}
+
+func TestHandleWebInterface(t *testing.T) {
+ tests := []struct {
+ name string
+ authenticated bool
+ expectStatus int
+ expectContent string
+ }{
+ {
+ name: "authenticated user",
+ authenticated: true,
+ expectStatus: http.StatusOK,
+ expectContent: "RDP Gateway",
+ },
+ {
+ name: "unauthenticated user",
+ authenticated: false,
+ expectStatus: http.StatusFound,
+ expectContent: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create handler with minimal configuration
+ handler := &Handler{
+ templatesPath: "./templates",
+ }
+ handler.loadWebConfig()
+ handler.loadHTMLTemplate()
+
+ // Create request
+ req := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ // Set identity context
+ user := identity.NewUser()
+ if tt.authenticated {
+ user.SetUserName("testuser")
+ user.SetAuthenticated(true)
+ user.SetAuthTime(time.Now())
+ }
+ req = identity.AddToRequestCtx(user, req)
+
+ // Call handler
+ handler.HandleWebInterface(w, req)
+
+ // Check response
+ if w.Code != tt.expectStatus {
+ t.Errorf("Expected status %d, got %d", tt.expectStatus, w.Code)
+ }
+
+ if tt.authenticated {
+ body := w.Body.String()
+ if !strings.Contains(body, tt.expectContent) {
+ t.Errorf("Expected response to contain %s", tt.expectContent)
+ }
+
+ // Check that it's a complete HTML document
+ if !strings.Contains(body, "") {
+ t.Error("Expected complete HTML document")
+ }
+
+ // Check for key elements (using fallback template)
+ expectedElements := []string{
+ "serversGrid",
+ "connectButton",
+ "loadServers",
+ "connectToServer",
+ }
+
+ for _, element := range expectedElements {
+ if !strings.Contains(body, element) {
+ t.Errorf("Expected HTML to contain %s", element)
+ }
+ }
+ } else {
+ // Check redirect location
+ location := w.Header().Get("Location")
+ if location != "/connect" {
+ t.Errorf("Expected redirect to /connect, got %s", location)
+ }
+ }
+ })
+ }
+}
+
+func TestHostSelectionIntegration(t *testing.T) {
+ // Test the full flow from host selection to RDP download
+ tests := []struct {
+ name string
+ hostSelection string
+ hosts []string
+ queryParams string
+ expectHost string
+ expectError bool
+ }{
+ {
+ name: "roundrobin selection",
+ hostSelection: "roundrobin",
+ hosts: []string{"host1.com", "host2.com", "host3.com"},
+ queryParams: "",
+ expectHost: "", // Will be one of the hosts
+ expectError: false,
+ },
+ {
+ name: "unsigned specific host",
+ hostSelection: "unsigned",
+ hosts: []string{"host1.com", "host2.com"},
+ queryParams: "?host=host2.com",
+ expectHost: "host2.com",
+ expectError: false,
+ },
+ {
+ name: "unsigned invalid host",
+ hostSelection: "unsigned",
+ hosts: []string{"host1.com", "host2.com"},
+ queryParams: "?host=invalid.com",
+ expectHost: "",
+ expectError: true,
+ },
+ {
+ name: "any host allowed",
+ hostSelection: "any",
+ hosts: []string{"host1.com"},
+ queryParams: "?host=any-host.com",
+ expectHost: "any-host.com",
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create handler
+ handler := &Handler{
+ hostSelection: tt.hostSelection,
+ hosts: tt.hosts,
+ gatewayAddress: &url.URL{Host: "gateway.example.com"},
+ }
+
+ // Create request for RDP download
+ req := httptest.NewRequest("GET", "/connect"+tt.queryParams, nil)
+ w := httptest.NewRecorder()
+
+ // Set authenticated user
+ user := identity.NewUser()
+ user.SetUserName("testuser")
+ user.SetAuthenticated(true)
+ user.SetAuthTime(time.Now())
+ req = identity.AddToRequestCtx(user, req)
+
+ // Mock the token generator to avoid errors
+ handler.paaTokenGenerator = func(ctx context.Context, user, host string) (string, error) {
+ return "mock-token", nil
+ }
+
+ // Call download handler
+ handler.HandleDownload(w, req)
+
+ if tt.expectError {
+ if w.Code == http.StatusOK {
+ t.Error("Expected error but got success")
+ }
+ } else {
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
+ }
+
+ // Check content type
+ contentType := w.Header().Get("Content-Type")
+ if contentType != "application/x-rdp" {
+ t.Errorf("Expected Content-Type application/x-rdp, got %s", contentType)
+ }
+
+ // Check content disposition
+ disposition := w.Header().Get("Content-Disposition")
+ if !strings.Contains(disposition, "attachment") || !strings.Contains(disposition, ".rdp") {
+ t.Errorf("Expected attachment disposition with .rdp file, got %s", disposition)
+ }
+
+ // Check RDP content for expected host
+ body := w.Body.String()
+ if tt.expectHost != "" {
+ if !strings.Contains(body, tt.expectHost) {
+ t.Errorf("Expected RDP content to contain host %s", tt.expectHost)
+ }
+ }
+
+ // Check for gateway configuration
+ if !strings.Contains(body, "gateway.example.com") {
+ t.Error("Expected RDP content to contain gateway address")
+ }
+
+ if !strings.Contains(body, "mock-token") {
+ t.Error("Expected RDP content to contain access token")
+ }
+ }
+ })
+ }
+}
diff --git a/dev/docker/Dockerfile b/dev/docker/Dockerfile
index 1631992..4bdb5cb 100644
--- a/dev/docker/Dockerfile
+++ b/dev/docker/Dockerfile
@@ -1,12 +1,13 @@
# builder stage
FROM golang:1.24-alpine as builder
+
# Install CA certificates explicitly in builder
-RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl ca-certificates
+RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl
# add user
RUN adduser --disabled-password --gecos "" --home /opt/rdpgw --uid 1001 rdpgw
-# certificate generation (your existing code)
+# certificate generation
RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \
random=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 32 | head -n 1) && \
openssl genrsa -des3 -passout pass:$random -out server.pass.key 2048 && \
@@ -16,7 +17,7 @@ RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \
-subj "/C=US/ST=VA/L=SomeCity/O=MyCompany/OU=MyDivision/CN=rdpgw" && \
openssl x509 -req -days 365 -in server.csr -signkey key.pem -out server.pem
-# build rdpgw and set rights (your existing code)
+# build rdpgw and set rights
ARG CACHEBUST
RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \
cd /app && \
@@ -29,7 +30,7 @@ RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \
FROM alpine:latest
# Install CA certificates in final stage
-RUN apk --no-cache add linux-pam musl tzdata ca-certificates
+RUN apk --no-cache add linux-pam musl tzdata ca-certificates && update-ca-certificates
# make tempdir in case filestore is used
ADD tmp.tar /
@@ -40,6 +41,9 @@ COPY --chown=1001 run.sh run.sh
COPY --chown=1001 --from=builder /opt/rdpgw /opt/rdpgw
COPY --chown=1001 --from=builder /etc/passwd /etc/passwd
+# Copy assets directory from the app source
+COPY --chown=1001 --from=builder /app/assets /opt/rdpgw/assets
+
USER 0
WORKDIR /opt/rdpgw
ENTRYPOINT ["/bin/sh", "/run.sh"]