Add webinterface

This commit is contained in:
Bolke de Bruin
2025-09-25 15:33:46 +02:00
parent 86c277cea4
commit 21a88d2dea
11 changed files with 1544 additions and 4 deletions

32
assets/connect.svg Normal file
View File

@@ -0,0 +1,32 @@
<svg width="200" height="200" viewBox="0 0 200 200" xmlns="http://www.w3.org/2000/svg">
<!-- Drop shadow -->
<defs>
<filter id="drop-shadow" x="-20%" y="-20%" width="140%" height="140%">
<feDropShadow dx="0" dy="4" stdDeviation="8" flood-color="rgba(0,0,0,0.2)"/>
</filter>
</defs>
<!-- Main circle -->
<circle cx="100" cy="100" r="90" fill="#D2572A" filter="url(#drop-shadow)"/>
<!-- White outline circle -->
<circle cx="100" cy="100" r="85" fill="none" stroke="#F5F5F5" stroke-width="2"/>
<!-- Monitor screen (centered) -->
<rect x="70" y="75" width="60" height="40" rx="2" ry="2" fill="none" stroke="white" stroke-width="4"/>
<!-- Monitor base -->
<rect x="95" y="115" width="30" height="6" rx="1" ry="1" fill="none" stroke="white" stroke-width="4"/>
<!-- Monitor stand -->
<rect x="105" y="121" width="10" height="10" fill="none" stroke="white" stroke-width="4"/>
<!-- Connection symbol circle background -->
<circle cx="125" cy="85" r="16" fill="white"/>
<!-- Connection symbol -->
<g stroke="#D2572A" stroke-width="2.5" fill="none" stroke-linecap="round">
<!-- Signal waves -->
<path d="M 115 85 Q 120 80, 125 85 Q 130 90, 135 85"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

10
assets/icon.svg Normal file
View File

@@ -0,0 +1,10 @@
<svg width="100" height="100" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
<!-- White background circle -->
<circle cx="50" cy="50" r="45" fill="white"/>
<!-- Black circular outline -->
<circle cx="50" cy="50" r="45" fill="none" stroke="black" stroke-width="4"/>
<!-- Bold black R -->
<text x="50" y="68" text-anchor="middle" font-family="Georgia, serif" font-size="58" font-weight="bold" fill="black" style="font-style: italic;">R</text>
</svg>

After

Width:  |  Height:  |  Size: 485 B

View File

@@ -213,6 +213,9 @@ func main() {
// for sso callbacks
r.HandleFunc("/tokeninfo", web.TokenInfo)
// API routes
api := r.PathPrefix("/api/v1").Subrouter()
// gateway endpoint
rdp := r.PathPrefix(gatewayEndPoint).Subrouter()
@@ -223,6 +226,18 @@ func main() {
r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload)))
r.HandleFunc("/callback", o.HandleCallback)
// Web interface and API routes (authenticated)
r.Handle("/", o.Authenticated(http.HandlerFunc(h.HandleWebInterface)))
api.Handle("/hosts", o.Authenticated(http.HandlerFunc(h.HandleHostList)))
api.Handle("/user", o.Authenticated(http.HandlerFunc(h.HandleUserInfo)))
// Static files (no authentication required)
r.HandleFunc("/static/style.css", h.ServeStaticFile("style.css"))
r.HandleFunc("/static/app.js", h.ServeStaticFile("app.js"))
// Asset files (no authentication required)
r.HandleFunc("/assets/connect.svg", h.ServeAssetFile("connect.svg"))
r.HandleFunc("/assets/icon.svg", h.ServeAssetFile("icon.svg"))
// only enable un-auth endpoint for openid only config
if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() && !conf.Server.HeaderEnabled() {
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)
@@ -241,6 +256,18 @@ func main() {
headerAuth := headerConfig.New()
r.Handle("/connect", headerAuth.Authenticated(http.HandlerFunc(h.HandleDownload)))
// Web interface and API routes (authenticated)
r.Handle("/", headerAuth.Authenticated(http.HandlerFunc(h.HandleWebInterface)))
api.Handle("/hosts", headerAuth.Authenticated(http.HandlerFunc(h.HandleHostList)))
api.Handle("/user", headerAuth.Authenticated(http.HandlerFunc(h.HandleUserInfo)))
// Static files (no authentication required)
r.HandleFunc("/static/style.css", h.ServeStaticFile("style.css"))
r.HandleFunc("/static/app.js", h.ServeStaticFile("app.js"))
// Asset files (no authentication required)
r.HandleFunc("/assets/connect.svg", h.ServeAssetFile("connect.svg"))
r.HandleFunc("/assets/icon.svg", h.ServeAssetFile("icon.svg"))
// only enable un-auth endpoint for header only config
if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() && !conf.Server.OpenIDEnabled() {
rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol)

View File

@@ -0,0 +1,95 @@
# RDP Gateway Web Interface Templates
This directory contains the customizable web interface templates for RDP Gateway.
## Files
### `index.html`
The main HTML template for the web interface. This file uses Go template syntax and can be customized to match your organization's branding.
**Template Variables Available:**
- `{{.Title}}` - Page title
- `{{.Logo}}` - Header logo text
- `{{.PageTitle}}` - Main page heading
- `{{.SelectServerMessage}}` - Default button text
- `{{.PreparingMessage}}` - Loading message
- `{{.AutoLaunchMessage}}` - Auto-launch notice text
### `style.css`
The CSS stylesheet for the web interface. Modify this file to customize:
- Colors and branding
- Layout and spacing
- Fonts and typography
- Responsive behavior
### `app.js`
The JavaScript file containing the web interface logic. This includes:
- Server list loading and rendering
- User authentication display
- **Automatic RDP client launching** (multiple methods)
- File download fallback
- Progress animations
### `config-example.json`
Example configuration structure showing available customization options. These values are set as defaults in the code but can be integrated with your main configuration system.
## Auto-Launch Functionality
The interface automatically attempts to launch RDP clients using **actual RDP file content**:
### How It Works:
1. **Fetches RDP Content**: Gets the complete RDP file configuration from `/api/rdp-content`
2. **Creates Data URL**: Converts RDP content to a downloadable blob
3. **Platform-Specific Launch**:
- **Windows**: Downloads .rdp file which auto-opens with mstsc
- **macOS**: Downloads .rdp file which auto-opens with Microsoft Remote Desktop
- **Universal**: Creates temporary download that browsers handle appropriately
### Technical Implementation:
- **`/api/rdp-content`** endpoint generates actual RDP file content with proper tokens
- **Data URLs** created from RDP content for browser download
- **Automatic file association** triggers RDP client launch
- **Graceful fallbacks** ensure users always get the RDP file
## Customization
To customize the interface:
1. **Copy this templates directory** to your preferred location
2. **Set the templates path** in your RDP Gateway configuration
3. **Edit the files** to match your branding requirements
4. **Restart RDP Gateway** to load the new templates
If template files are missing, the system automatically falls back to embedded templates to ensure the interface remains functional.
## API Endpoints
The web interface uses these authenticated API endpoints:
- **`/api/hosts`** - Returns available servers for the user (JSON)
- **`/api/user`** - Returns current user information (JSON)
- **`/api/rdp-content`** - Returns RDP file content as text for auto-launch
- **`/connect`** - Downloads RDP file (traditional endpoint)
## Static File Serving
The following URLs serve static files:
- `/static/style.css` - CSS stylesheet
- `/static/app.js` - JavaScript application
These files are served without authentication requirements for better performance.
## Browser Compatibility
The interface supports:
- Modern browsers (Chrome, Firefox, Safari, Edge)
- Mobile responsive design
- Protocol handlers for RDP client launching
- Graceful fallbacks for unsupported features
## Security Considerations
- Template files are served from the server filesystem
- Static files include cache headers for performance
- User authentication is required for the main interface
- API endpoints validate authentication before serving data

223
cmd/rdpgw/templates/app.js Normal file
View File

@@ -0,0 +1,223 @@
// RDP Gateway Web Interface
let userInfo = null;
// Theme handling - SVG logo works for both light and dark modes
function updateLogo() {
const logoImage = document.getElementById('logoImage');
if (logoImage) {
logoImage.src = '/assets/icon.svg';
}
}
// Configuration
const config = {
progressAnimationDuration: 2000, // ms for progress bar animation
};
// Get user initials for avatar
function getUserInitials(name) {
return name.split(' ').map(word => word.charAt(0)).slice(0, 2).join('').toUpperCase() || 'U';
}
// Load user information
async function loadUserInfo() {
try {
const response = await fetch('/api/v1/user');
if (response.ok) {
userInfo = await response.json();
document.getElementById('username').textContent = userInfo.username;
document.getElementById('userAvatar').textContent = getUserInitials(userInfo.username);
} else {
throw new Error('Failed to load user info');
}
} catch (error) {
showError('Failed to load user information');
}
}
// Load available servers
async function loadServers() {
try {
const response = await fetch('/api/v1/hosts');
if (response.ok) {
const servers = await response.json();
renderServers(servers);
} else {
throw new Error('Failed to load servers');
}
} catch (error) {
showError('Failed to load available servers');
}
}
// Render servers in the grid
function renderServers(servers) {
const grid = document.getElementById('serversGrid');
grid.innerHTML = '';
servers.forEach(server => {
const card = document.createElement('div');
card.className = 'server-card';
const connectButton = document.createElement('button');
connectButton.className = 'server-connect-button';
connectButton.textContent = `Connect to ${server.name}`;
connectButton.onclick = (e) => {
e.stopPropagation();
connectToServer(server, connectButton);
};
card.innerHTML = `
<div class="server-content">
<div class="server-icon">
<img src="/assets/connect.svg" alt="Connect" />
</div>
<div class="server-info">
<div class="server-name">${server.name}</div>
<div class="server-description">${server.description}</div>
</div>
</div>
`;
card.appendChild(connectButton);
grid.appendChild(card);
});
}
// Show error message
function showError(message) {
const errorDiv = document.getElementById('error');
errorDiv.textContent = message;
errorDiv.style.display = 'block';
hideSuccess();
}
// Hide error message
function hideError() {
document.getElementById('error').style.display = 'none';
}
// Show success message
function showSuccess(message) {
const successDiv = document.getElementById('success');
successDiv.textContent = message;
successDiv.style.display = 'block';
hideError();
}
// Hide success message
function hideSuccess() {
document.getElementById('success').style.display = 'none';
}
// Animate progress bar
function animateProgress(duration = config.progressAnimationDuration) {
const progressFill = document.getElementById('progressFill');
progressFill.style.width = '0%';
let startTime = null;
function animate(currentTime) {
if (!startTime) startTime = currentTime;
const elapsed = currentTime - startTime;
const progress = Math.min(elapsed / duration, 1);
progressFill.style.width = (progress * 100) + '%';
if (progress < 1) {
requestAnimationFrame(animate);
}
}
requestAnimationFrame(animate);
}
// Generate filename with user initials and random prefix
function generateFilename() {
if (!userInfo) return 'connection.rdp';
const initials = getUserInitials(userInfo.username);
const randomStr = Math.random().toString(36).substring(2, 8).toUpperCase();
return `${initials}_${randomStr}.rdp`;
}
// Download RDP file
async function downloadRDPFile(url) {
const link = document.createElement('a');
link.href = url;
link.download = generateFilename();
link.style.display = 'none';
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
}
// Connect to server
async function connectToServer(server, button) {
if (!server) return;
hideError();
hideSuccess();
const originalButtonText = button.textContent;
const loading = document.getElementById('loading');
// Update UI for loading state
button.disabled = true;
button.textContent = 'Downloading...';
loading.style.display = 'block';
// Start progress animation
animateProgress();
try {
// Build the RDP download URL
let url = '/connect';
if (server.address) {
url += '?host=' + encodeURIComponent(server.address);
}
// Wait a moment for better UX
await new Promise(resolve => setTimeout(resolve, 500));
// Download the RDP file
await downloadRDPFile(url);
showSuccess('RDP file downloaded. Please open it with your RDP client.');
// Reset UI after a delay
setTimeout(() => {
button.disabled = false;
button.textContent = originalButtonText;
loading.style.display = 'none';
}, 2000);
} catch (error) {
console.error('Connection error:', error);
showError('Failed to download RDP file. Please try again.');
// Reset UI immediately on error
button.disabled = false;
button.textContent = originalButtonText;
loading.style.display = 'none';
}
}
// Initialize the application
document.addEventListener('DOMContentLoaded', async () => {
// Set initial logo based on theme
updateLogo();
// Load data
await loadUserInfo();
await loadServers();
// No additional event handlers needed - buttons are handled in renderServers
});
// Handle visibility change (for auto-refresh when tab becomes visible)
document.addEventListener('visibilitychange', () => {
if (!document.hidden) {
// Refresh server list when tab becomes visible
loadServers();
}
});

View File

@@ -0,0 +1,22 @@
{
"branding": {
"title": "RDP Gateway",
"logo": "RDP Gateway",
"page_title": "Select a Server to Connect"
},
"messages": {
"select_server": "Select a server to connect",
"preparing": "Preparing your connection..."
},
"ui": {
"progress_animation_duration_ms": 2000,
"auto_select_default": true,
"show_user_avatar": true
},
"theme": {
"primary_color": "#667eea",
"secondary_color": "#764ba2",
"success_color": "#38b2ac",
"error_color": "#c53030"
}
}

View File

@@ -0,0 +1,45 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
<link rel="stylesheet" href="/static/style.css">
<link rel="icon" type="image/svg+xml" href="/assets/icon.svg">
<link rel="alternate icon" type="image/x-icon" href="/assets/icon.svg">
</head>
<body>
<div class="header">
<div class="logo">
<img src="/assets/icon.svg" alt="Logo" id="logoImage">
{{.Logo}}
</div>
<div class="user-info">
<div class="user-avatar" id="userAvatar"></div>
<span id="username">Loading...</span>
</div>
</div>
<div class="main">
<div class="container">
<h1 class="title">{{.PageTitle}}</h1>
<div class="error" id="error"></div>
<div class="success" id="success"></div>
<div class="servers-grid" id="serversGrid">
<!-- Servers will be loaded here -->
</div>
<div class="loading" id="loading">
{{.PreparingMessage}}
<div class="progress-bar">
<div class="progress-fill" id="progressFill"></div>
</div>
</div>
</div>
</div>
<script src="/static/app.js"></script>
</body>
</html>

View File

@@ -0,0 +1,305 @@
1:root {
/* Light mode colors (OKLCH) */
--background: oklch(1 0 0);
--foreground: oklch(0.145 0 0);
--primary: oklch(0.205 0 0);
--secondary: oklch(0.97 0 0);
--accent: oklch(0.97 0 0);
--destructive: oklch(0.577 0.245 27.325);
--border: oklch(0.922 0 0);
--muted: oklch(0.97 0 0);
--muted-foreground: oklch(0.466 0 0);
--card: oklch(1 0 0);
--popover: oklch(1 0 0);
/* Border radius - Pocket ID system */
--radius: 0.75rem;
--radius-sm: calc(var(--radius) - 4px);
--radius-md: calc(var(--radius) - 2px);
--radius-lg: var(--radius);
--radius-xl: calc(var(--radius) + 4px);
}
@media (prefers-color-scheme: dark) {
:root {
/* Dark mode colors */
--background: oklch(0.145 0 0);
--foreground: oklch(0.985 0 0);
--primary: oklch(0.922 0 0);
--secondary: oklch(0.269 0 0);
--accent: oklch(0.269 0 0);
--destructive: oklch(0.704 0.191 22.216);
--border: oklch(0.269 0 0);
--muted: oklch(0.205 0 0);
--muted-foreground: oklch(0.722 0 0);
--card: oklch(0.145 0 0);
--popover: oklch(0.145 0 0);
}
}
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background: var(--background);
color: var(--foreground);
min-height: 100vh;
display: flex;
flex-direction: column;
transition: color 0.2s, background-color 0.2s;
}
.header {
background: var(--card);
border-bottom: 1px solid var(--border);
padding: 1rem 2rem;
display: flex;
justify-content: space-between;
align-items: center;
box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05);
}
.logo {
display: flex;
align-items: center;
gap: 0.75rem;
font-size: 1.25rem;
font-weight: 600;
color: var(--foreground);
}
.logo img {
height: 2rem;
width: auto;
}
.user-info {
display: flex;
align-items: center;
gap: 1rem;
color: var(--foreground);
}
.user-avatar {
width: 40px;
height: 40px;
border-radius: 50%;
background: var(--secondary);
border: 1px solid var(--border);
display: flex;
align-items: center;
justify-content: center;
font-weight: 600;
color: var(--foreground);
}
.main {
flex: 1;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 2rem;
}
.container {
background: var(--card);
border-radius: var(--radius-xl);
border: 1px solid var(--border);
padding: 2rem;
box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1);
max-width: 800px;
width: 100%;
}
.title {
text-align: center;
margin-bottom: 2rem;
color: var(--foreground);
font-weight: 600;
}
.servers-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(300px, 1fr));
gap: 1rem;
margin-bottom: 2rem;
}
.server-card {
border: 1px solid var(--border);
border-radius: var(--radius-lg);
padding: 0;
transition: all 0.2s;
position: relative;
background: var(--card);
display: flex;
flex-direction: column;
overflow: hidden;
}
.server-card:hover {
transform: translateY(-1px);
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
.server-content {
padding: 1.5rem;
flex: 1;
display: flex;
gap: 1rem;
align-items: flex-start;
}
.server-icon {
flex-shrink: 0;
line-height: 1;
display: flex;
align-items: center;
justify-content: center;
}
.server-icon img {
width: 3rem;
height: 3rem;
object-fit: contain;
}
.server-info {
flex: 1;
min-width: 0;
}
.server-name {
font-size: 1.1rem;
font-weight: 600;
margin-bottom: 0.5rem;
color: var(--foreground);
word-wrap: break-word;
}
.server-description {
color: var(--muted-foreground);
font-size: 0.9rem;
line-height: 1.4;
}
.server-connect-button {
width: 100%;
background: var(--primary);
color: var(--background);
border: none;
border-top: 1px solid var(--border);
border-radius: 0 0 var(--radius-lg) var(--radius-lg);
padding: 1rem 1.5rem;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
margin: 0;
}
.server-connect-button:hover:not(:disabled) {
background: var(--primary);
opacity: 0.9;
}
.server-connect-button:disabled {
background: var(--muted);
color: var(--muted-foreground);
opacity: 0.5;
cursor: not-allowed;
}
.loading {
display: none;
text-align: center;
margin-top: 1rem;
color: var(--foreground);
opacity: 0.7;
}
.error {
background: color-mix(in oklch, var(--destructive) 10%, transparent);
color: var(--destructive);
border: 1px solid color-mix(in oklch, var(--destructive) 20%, transparent);
padding: 1rem;
border-radius: var(--radius-md);
margin-bottom: 1rem;
display: none;
}
.success {
background: color-mix(in oklch, var(--primary) 10%, transparent);
color: var(--primary);
border: 1px solid color-mix(in oklch, var(--primary) 20%, transparent);
padding: 1rem;
border-radius: var(--radius-md);
margin-bottom: 1rem;
display: none;
text-align: center;
}
.progress-bar {
width: 100%;
height: 4px;
background: var(--border);
border-radius: var(--radius-sm);
overflow: hidden;
margin-top: 0.5rem;
}
.progress-fill {
height: 100%;
background: var(--primary);
width: 0%;
transition: width 0.3s ease;
}
/* Responsive design */
@media (max-width: 768px) {
.header {
padding: 1rem;
}
.main {
padding: 1rem;
}
.container {
padding: 1.5rem;
}
.servers-grid {
grid-template-columns: 1fr;
}
.logo {
font-size: 1.2rem;
}
}
/* Custom scrollbar */
::-webkit-scrollbar {
width: 8px;
}
::-webkit-scrollbar-track {
background: rgba(255, 255, 255, 0.1);
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: rgba(255, 255, 255, 0.3);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: rgba(255, 255, 255, 0.5);
}

View File

@@ -5,13 +5,17 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/maphash"
"html/template"
"log"
rnd "math/rand"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
@@ -37,6 +41,31 @@ type Config struct {
TemplateFile string
RdpSigningCert string
RdpSigningKey string
TemplatesPath string
}
// WebConfig represents the web interface configuration
type WebConfig struct {
Branding struct {
Title string `json:"title"`
Logo string `json:"logo"`
PageTitle string `json:"page_title"`
} `json:"branding"`
Messages struct {
SelectServer string `json:"select_server"`
Preparing string `json:"preparing"`
} `json:"messages"`
UI struct {
ProgressAnimationDurationMs int `json:"progress_animation_duration_ms"`
AutoSelectDefault bool `json:"auto_select_default"`
ShowUserAvatar bool `json:"show_user_avatar"`
} `json:"ui"`
Theme struct {
PrimaryColor string `json:"primary_color"`
SecondaryColor string `json:"secondary_color"`
SuccessColor string `json:"success_color"`
ErrorColor string `json:"error_color"`
} `json:"theme"`
}
type RdpOpts struct {
@@ -57,6 +86,9 @@ type Handler struct {
rdpOpts RdpOpts
rdpDefaults string
rdpSigner *rdpsign.Signer
templatesPath string
webConfig *WebConfig
htmlTemplate *template.Template
}
func (c *Config) NewHandler() *Handler {
@@ -75,6 +107,7 @@ func (c *Config) NewHandler() *Handler {
hostSelection: c.HostSelection,
rdpOpts: c.RdpOpts,
rdpDefaults: c.TemplateFile,
templatesPath: c.TemplatesPath,
}
// set up RDP signer if config values are set
@@ -87,9 +120,231 @@ func (c *Config) NewHandler() *Handler {
handler.rdpSigner = signer
}
// Set up templates path
if handler.templatesPath == "" {
handler.templatesPath = "./templates"
}
// Load web configuration
handler.loadWebConfig()
// Load HTML template
handler.loadHTMLTemplate()
return handler
}
// loadWebConfig sets up the web interface configuration with defaults
func (h *Handler) loadWebConfig() {
// Set defaults - these can be overridden by the main config system later
h.webConfig = &WebConfig{}
h.webConfig.Branding.Title = "RDP Gateway"
h.webConfig.Branding.Logo = "RDP Gateway"
h.webConfig.Branding.PageTitle = "Select a Server to Connect"
h.webConfig.Messages.SelectServer = "Select a server to connect"
h.webConfig.Messages.Preparing = "Preparing your connection..."
h.webConfig.UI.ProgressAnimationDurationMs = 2000
h.webConfig.UI.AutoSelectDefault = true
h.webConfig.UI.ShowUserAvatar = true
h.webConfig.Theme.PrimaryColor = "#667eea"
h.webConfig.Theme.SecondaryColor = "#764ba2"
h.webConfig.Theme.SuccessColor = "#38b2ac"
h.webConfig.Theme.ErrorColor = "#c53030"
}
// loadHTMLTemplate loads the HTML template
func (h *Handler) loadHTMLTemplate() {
templatePath := filepath.Join(h.templatesPath, "index.html")
tmpl, err := template.ParseFiles(templatePath)
if err != nil {
log.Printf("Warning: Failed to load HTML template %s: %v", templatePath, err)
log.Printf("Using embedded fallback template")
h.htmlTemplate = template.Must(template.New("index").Parse(fallbackHTMLTemplate))
} else {
h.htmlTemplate = tmpl
log.Printf("Loaded HTML template from %s", templatePath)
}
}
// ServeStaticFile serves static files from the templates directory
func (h *Handler) ServeStaticFile(filename string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
filePath := filepath.Join(h.templatesPath, filename)
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
http.NotFound(w, r)
return
}
// Set appropriate content type
switch filepath.Ext(filename) {
case ".css":
w.Header().Set("Content-Type", "text/css")
case ".js":
w.Header().Set("Content-Type", "application/javascript")
case ".svg":
w.Header().Set("Content-Type", "image/svg+xml")
case ".png":
w.Header().Set("Content-Type", "image/png")
case ".jpg", ".jpeg":
w.Header().Set("Content-Type", "image/jpeg")
default:
// Check if it's one of our logo files without extension
if filename == "logo.png" || filename == "logo_light_background.png" || filename == "logo_dark_background.png" {
w.Header().Set("Content-Type", "image/png")
}
}
// Enable caching for static files
w.Header().Set("Cache-Control", "public, max-age=3600")
http.ServeFile(w, r, filePath)
}
}
// ServeAssetFile serves asset files from the assets directory
func (h *Handler) ServeAssetFile(filename string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var filePath string
// Try multiple possible locations for assets
possiblePaths := []string{
// Docker container paths
"./assets/" + filename,
"/app/assets/" + filename,
"/opt/rdpgw/assets/" + filename,
// Development paths
filepath.Join("assets", filename),
}
// Add icon.svg to the check as well
if filename == "icon.svg" {
possiblePaths = append(possiblePaths, "./icon.svg", "/app/icon.svg", "/opt/rdpgw/icon.svg")
}
// If we have templates path, try relative to it
if h.templatesPath != "" {
templatesDir, err := filepath.Abs(h.templatesPath)
if err == nil {
// Navigate up from templates to find assets
currentDir := templatesDir
for i := 0; i < 5; i++ {
parentDir := filepath.Dir(currentDir)
if parentDir == currentDir {
break
}
possiblePaths = append(possiblePaths, filepath.Join(parentDir, "assets", filename))
currentDir = parentDir
}
}
}
// Test each possible path
for _, testPath := range possiblePaths {
if _, err := os.Stat(testPath); err == nil {
filePath = testPath
break
}
}
if filePath == "" {
log.Printf("Asset file not found: %s. Tried paths: %v", filename, possiblePaths)
http.NotFound(w, r)
return
}
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
http.NotFound(w, r)
return
}
// Set appropriate content type
switch filepath.Ext(filename) {
case ".svg":
w.Header().Set("Content-Type", "image/svg+xml")
case ".png":
w.Header().Set("Content-Type", "image/png")
case ".jpg", ".jpeg":
w.Header().Set("Content-Type", "image/jpeg")
default:
// Check if it's one of our asset files without extension
if filename == "logo_light_background.png" || filename == "logo_dark_background.png" || filename == "connect.svg" {
if filepath.Ext(filename) == ".png" || filename == "logo_light_background.png" || filename == "logo_dark_background.png" {
w.Header().Set("Content-Type", "image/png")
} else {
w.Header().Set("Content-Type", "image/svg+xml")
}
}
}
// Enable caching for asset files
w.Header().Set("Cache-Control", "public, max-age=3600")
http.ServeFile(w, r, filePath)
}
}
// fallbackHTMLTemplate is used when external template file is not available
const fallbackHTMLTemplate = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{{.Title}}</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; }
.container { max-width: 800px; margin: 2rem auto; padding: 2rem; background: white; border-radius: 12px; }
.server-card { border: 2px solid #e2e8f0; border-radius: 8px; padding: 1.5rem; margin: 1rem; cursor: pointer; }
.server-card:hover { border-color: #667eea; }
.server-card.selected { border-color: #667eea; background: rgba(102, 126, 234, 0.05); }
.connect-button { width: 100%; background: #667eea; color: white; border: none; border-radius: 8px;
padding: 1rem 2rem; font-size: 1.1rem; cursor: pointer; }
.connect-button:disabled { background: #a0aec0; cursor: not-allowed; }
</style>
</head>
<body>
<div class="container">
<h1>{{.PageTitle}}</h1>
<div id="serversGrid"></div>
<button class="connect-button" id="connectButton" disabled>{{.SelectServerMessage}}</button>
<div id="loading" style="display:none;">{{.PreparingMessage}}</div>
</div>
<script>
// Fallback minimal JavaScript
let selectedServer = null;
async function loadServers() {
const response = await fetch('/api/hosts');
const servers = await response.json();
const grid = document.getElementById('serversGrid');
servers.forEach(server => {
const card = document.createElement('div');
card.className = 'server-card';
card.innerHTML = server.icon + ' ' + server.name + '<br><small>' + server.description + '</small>';
card.onclick = () => {
document.querySelectorAll('.server-card').forEach(c => c.classList.remove('selected'));
card.classList.add('selected');
selectedServer = server;
document.getElementById('connectButton').disabled = false;
};
grid.appendChild(card);
});
}
async function connectToServer() {
if (!selectedServer) return;
let url = '/connect';
if (selectedServer.address) url += '?host=' + encodeURIComponent(selectedServer.address);
window.location.href = url;
}
document.addEventListener('DOMContentLoaded', loadServers);
document.getElementById('connectButton').onclick = connectToServer;
</script>
</body>
</html>`
func (h *Handler) selectRandomHost() string {
r := rnd.New(rnd.NewSource(int64(new(maphash.Hash).Sum64())))
host := h.hosts[r.Intn(len(h.hosts))]
@@ -262,3 +517,107 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
// return signd rdp file
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
}
// Host represents a host available for connection
type Host struct {
ID string `json:"id"`
Name string `json:"name"`
Address string `json:"address"`
Description string `json:"description"`
IsDefault bool `json:"isDefault"`
}
// UserInfo represents the current authenticated user
type UserInfo struct {
Username string `json:"username"`
Authenticated bool `json:"authenticated"`
AuthTime time.Time `json:"authTime"`
}
// HandleHostList returns the list of available hosts for the authenticated user
func (h *Handler) HandleHostList(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequestCtx(r)
if !id.Authenticated() {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
var hosts []Host
// Simplified host selection - all modes work the same for the user
if h.hostSelection == "roundrobin" {
hosts = append(hosts, Host{
ID: "roundrobin",
Name: "Available Servers",
Address: "",
Description: "Connect to an available server automatically",
IsDefault: true,
})
} else {
// For all other modes (signed, unsigned, any), show the actual hosts
for i, hostAddr := range h.hosts {
hosts = append(hosts, Host{
ID: fmt.Sprintf("host_%d", i),
Name: hostAddr,
Address: hostAddr,
Description: fmt.Sprintf("Connect to %s", hostAddr),
IsDefault: i == 0,
})
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(hosts)
}
// HandleUserInfo returns information about the current authenticated user
func (h *Handler) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequestCtx(r)
if !id.Authenticated() {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
userInfo := UserInfo{
Username: id.UserName(),
Authenticated: id.Authenticated(),
AuthTime: id.AuthTime(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(userInfo)
}
// HandleWebInterface serves the main web interface
func (h *Handler) HandleWebInterface(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequestCtx(r)
if !id.Authenticated() {
// Redirect to authentication
http.Redirect(w, r, "/connect", http.StatusFound)
return
}
// Template data
templateData := struct {
Title string
Logo string
PageTitle string
SelectServerMessage string
PreparingMessage string
}{
Title: h.webConfig.Branding.Title,
Logo: h.webConfig.Branding.Logo,
PageTitle: h.webConfig.Branding.PageTitle,
SelectServerMessage: h.webConfig.Messages.SelectServer,
PreparingMessage: h.webConfig.Messages.Preparing,
}
w.Header().Set("Content-Type", "text/html")
if err := h.htmlTemplate.Execute(w, templateData); err != nil {
log.Printf("Failed to execute template: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,418 @@
package web
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
)
func TestHandleHostList(t *testing.T) {
tests := []struct {
name string
hostSelection string
hosts []string
authenticated bool
expectedCount int
expectedType string
}{
{
name: "roundrobin mode",
hostSelection: "roundrobin",
hosts: []string{"host1.example.com", "host2.example.com"},
authenticated: true,
expectedCount: 1,
expectedType: "roundrobin",
},
{
name: "unsigned mode",
hostSelection: "unsigned",
hosts: []string{"host1.example.com", "host2.example.com", "host3.example.com"},
authenticated: true,
expectedCount: 3,
expectedType: "individual",
},
{
name: "any mode",
hostSelection: "any",
hosts: []string{"host1.example.com"},
authenticated: true,
expectedCount: 1,
expectedType: "individual",
},
{
name: "signed mode",
hostSelection: "signed",
hosts: []string{"host1.example.com", "host2.example.com"},
authenticated: true,
expectedCount: 2,
expectedType: "signed",
},
{
name: "unauthenticated user",
hostSelection: "roundrobin",
hosts: []string{"host1.example.com"},
authenticated: false,
expectedCount: 0,
expectedType: "error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create handler
handler := &Handler{
hostSelection: tt.hostSelection,
hosts: tt.hosts,
}
// Create request
req := httptest.NewRequest("GET", "/api/v1/hosts", nil)
w := httptest.NewRecorder()
// Set identity context
user := identity.NewUser()
if tt.authenticated {
user.SetUserName("testuser")
user.SetAuthenticated(true)
user.SetAuthTime(time.Now())
}
req = identity.AddToRequestCtx(user, req)
// Call handler
handler.HandleHostList(w, req)
if !tt.authenticated {
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
return
}
// Check response
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
var hosts []Host
err := json.Unmarshal(w.Body.Bytes(), &hosts)
if err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(hosts) != tt.expectedCount {
t.Errorf("Expected %d hosts, got %d", tt.expectedCount, len(hosts))
}
if len(hosts) > 0 {
switch tt.expectedType {
case "roundrobin":
if hosts[0].ID != "roundrobin" {
t.Errorf("Expected roundrobin host, got %s", hosts[0].ID)
}
case "individual":
if !strings.Contains(hosts[0].Name, tt.hosts[0]) {
t.Errorf("Expected host name to contain %s, got %s", tt.hosts[0], hosts[0].Name)
}
case "signed":
if !strings.Contains(hosts[0].Name, tt.hosts[0]) {
t.Errorf("Expected host name to contain %s, got %s", tt.hosts[0], hosts[0].Name)
}
}
// Check that first host is marked as default
hasDefault := false
for _, host := range hosts {
if host.IsDefault {
hasDefault = true
break
}
}
if !hasDefault {
t.Error("Expected at least one host to be marked as default")
}
}
})
}
}
func TestHandleUserInfo(t *testing.T) {
tests := []struct {
name string
authenticated bool
username string
authTime time.Time
}{
{
name: "authenticated user",
authenticated: true,
username: "john.doe@example.com",
authTime: time.Now(),
},
{
name: "unauthenticated user",
authenticated: false,
username: "",
authTime: time.Time{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create handler
handler := &Handler{}
// Create request
req := httptest.NewRequest("GET", "/api/v1/user", nil)
w := httptest.NewRecorder()
// Set identity context
user := identity.NewUser()
if tt.authenticated {
user.SetUserName(tt.username)
user.SetAuthenticated(true)
user.SetAuthTime(tt.authTime)
}
req = identity.AddToRequestCtx(user, req)
// Call handler
handler.HandleUserInfo(w, req)
if !tt.authenticated {
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
return
}
// Check response
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
}
var userInfo UserInfo
err := json.Unmarshal(w.Body.Bytes(), &userInfo)
if err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if userInfo.Username != tt.username {
t.Errorf("Expected username %s, got %s", tt.username, userInfo.Username)
}
if userInfo.Authenticated != tt.authenticated {
t.Errorf("Expected authenticated %v, got %v", tt.authenticated, userInfo.Authenticated)
}
if tt.authenticated && userInfo.AuthTime.IsZero() {
t.Error("Expected non-zero auth time for authenticated user")
}
})
}
}
func TestHandleWebInterface(t *testing.T) {
tests := []struct {
name string
authenticated bool
expectStatus int
expectContent string
}{
{
name: "authenticated user",
authenticated: true,
expectStatus: http.StatusOK,
expectContent: "RDP Gateway",
},
{
name: "unauthenticated user",
authenticated: false,
expectStatus: http.StatusFound,
expectContent: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create handler with minimal configuration
handler := &Handler{
templatesPath: "./templates",
}
handler.loadWebConfig()
handler.loadHTMLTemplate()
// Create request
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// Set identity context
user := identity.NewUser()
if tt.authenticated {
user.SetUserName("testuser")
user.SetAuthenticated(true)
user.SetAuthTime(time.Now())
}
req = identity.AddToRequestCtx(user, req)
// Call handler
handler.HandleWebInterface(w, req)
// Check response
if w.Code != tt.expectStatus {
t.Errorf("Expected status %d, got %d", tt.expectStatus, w.Code)
}
if tt.authenticated {
body := w.Body.String()
if !strings.Contains(body, tt.expectContent) {
t.Errorf("Expected response to contain %s", tt.expectContent)
}
// Check that it's a complete HTML document
if !strings.Contains(body, "<!DOCTYPE html>") {
t.Error("Expected complete HTML document")
}
// Check for key elements (using fallback template)
expectedElements := []string{
"serversGrid",
"connectButton",
"loadServers",
"connectToServer",
}
for _, element := range expectedElements {
if !strings.Contains(body, element) {
t.Errorf("Expected HTML to contain %s", element)
}
}
} else {
// Check redirect location
location := w.Header().Get("Location")
if location != "/connect" {
t.Errorf("Expected redirect to /connect, got %s", location)
}
}
})
}
}
func TestHostSelectionIntegration(t *testing.T) {
// Test the full flow from host selection to RDP download
tests := []struct {
name string
hostSelection string
hosts []string
queryParams string
expectHost string
expectError bool
}{
{
name: "roundrobin selection",
hostSelection: "roundrobin",
hosts: []string{"host1.com", "host2.com", "host3.com"},
queryParams: "",
expectHost: "", // Will be one of the hosts
expectError: false,
},
{
name: "unsigned specific host",
hostSelection: "unsigned",
hosts: []string{"host1.com", "host2.com"},
queryParams: "?host=host2.com",
expectHost: "host2.com",
expectError: false,
},
{
name: "unsigned invalid host",
hostSelection: "unsigned",
hosts: []string{"host1.com", "host2.com"},
queryParams: "?host=invalid.com",
expectHost: "",
expectError: true,
},
{
name: "any host allowed",
hostSelection: "any",
hosts: []string{"host1.com"},
queryParams: "?host=any-host.com",
expectHost: "any-host.com",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create handler
handler := &Handler{
hostSelection: tt.hostSelection,
hosts: tt.hosts,
gatewayAddress: &url.URL{Host: "gateway.example.com"},
}
// Create request for RDP download
req := httptest.NewRequest("GET", "/connect"+tt.queryParams, nil)
w := httptest.NewRecorder()
// Set authenticated user
user := identity.NewUser()
user.SetUserName("testuser")
user.SetAuthenticated(true)
user.SetAuthTime(time.Now())
req = identity.AddToRequestCtx(user, req)
// Mock the token generator to avoid errors
handler.paaTokenGenerator = func(ctx context.Context, user, host string) (string, error) {
return "mock-token", nil
}
// Call download handler
handler.HandleDownload(w, req)
if tt.expectError {
if w.Code == http.StatusOK {
t.Error("Expected error but got success")
}
} else {
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if contentType != "application/x-rdp" {
t.Errorf("Expected Content-Type application/x-rdp, got %s", contentType)
}
// Check content disposition
disposition := w.Header().Get("Content-Disposition")
if !strings.Contains(disposition, "attachment") || !strings.Contains(disposition, ".rdp") {
t.Errorf("Expected attachment disposition with .rdp file, got %s", disposition)
}
// Check RDP content for expected host
body := w.Body.String()
if tt.expectHost != "" {
if !strings.Contains(body, tt.expectHost) {
t.Errorf("Expected RDP content to contain host %s", tt.expectHost)
}
}
// Check for gateway configuration
if !strings.Contains(body, "gateway.example.com") {
t.Error("Expected RDP content to contain gateway address")
}
if !strings.Contains(body, "mock-token") {
t.Error("Expected RDP content to contain access token")
}
}
})
}
}

View File

@@ -1,12 +1,13 @@
# builder stage
FROM golang:1.24-alpine as builder
# Install CA certificates explicitly in builder
RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl ca-certificates
RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl
# add user
RUN adduser --disabled-password --gecos "" --home /opt/rdpgw --uid 1001 rdpgw
# certificate generation (your existing code)
# certificate generation
RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \
random=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 32 | head -n 1) && \
openssl genrsa -des3 -passout pass:$random -out server.pass.key 2048 && \
@@ -16,7 +17,7 @@ RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \
-subj "/C=US/ST=VA/L=SomeCity/O=MyCompany/OU=MyDivision/CN=rdpgw" && \
openssl x509 -req -days 365 -in server.csr -signkey key.pem -out server.pem
# build rdpgw and set rights (your existing code)
# build rdpgw and set rights
ARG CACHEBUST
RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \
cd /app && \
@@ -29,7 +30,7 @@ RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \
FROM alpine:latest
# Install CA certificates in final stage
RUN apk --no-cache add linux-pam musl tzdata ca-certificates
RUN apk --no-cache add linux-pam musl tzdata ca-certificates && update-ca-certificates
# make tempdir in case filestore is used
ADD tmp.tar /
@@ -40,6 +41,9 @@ COPY --chown=1001 run.sh run.sh
COPY --chown=1001 --from=builder /opt/rdpgw /opt/rdpgw
COPY --chown=1001 --from=builder /etc/passwd /etc/passwd
# Copy assets directory from the app source
COPY --chown=1001 --from=builder /app/assets /opt/rdpgw/assets
USER 0
WORKDIR /opt/rdpgw
ENTRYPOINT ["/bin/sh", "/run.sh"]