diff --git a/.dockerignore b/.dockerignore index c748048e..d4f63d63 100644 --- a/.dockerignore +++ b/.dockerignore @@ -28,9 +28,9 @@ LICENSE CONTRIBUTING.md dist .git -migrations/ +server/migrations/ config/ build.ts tsconfig.json Dockerfile* -migrations/ +drizzle.config.ts diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 7358fa2a..5a776c99 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -289,22 +289,14 @@ jobs: echo "LATEST_BADGER_TAG=$LATEST_TAG" >> $GITHUB_ENV shell: bash - - name: Update install/main.go - run: | - PANGOLIN_VERSION=${{ env.TAG }} - GERBIL_VERSION=${{ env.LATEST_GERBIL_TAG }} - BADGER_VERSION=${{ env.LATEST_BADGER_TAG }} - sed -i "s/config.PangolinVersion = \".*\"/config.PangolinVersion = \"$PANGOLIN_VERSION\"/" install/main.go - sed -i "s/config.GerbilVersion = \".*\"/config.GerbilVersion = \"$GERBIL_VERSION\"/" install/main.go - sed -i "s/config.BadgerVersion = \".*\"/config.BadgerVersion = \"$BADGER_VERSION\"/" install/main.go - echo "Updated install/main.go with Pangolin version $PANGOLIN_VERSION, Gerbil version $GERBIL_VERSION, and Badger version $BADGER_VERSION" - cat install/main.go - shell: bash - - name: Build installer working-directory: install run: | - make go-build-release + make go-build-release \ + PANGOLIN_VERSION=${{ env.TAG }} \ + GERBIL_VERSION=${{ env.LATEST_GERBIL_TAG }} \ + BADGER_VERSION=${{ env.LATEST_BADGER_TAG }} + shell: bash - name: Upload artifacts from /install/bin uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 diff --git a/Dockerfile b/Dockerfile index 12c519b7..07acba26 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ -FROM node:24-alpine AS base +FROM node:24-slim AS base WORKDIR /app -RUN apk add --no-cache python3 make g++ +RUN apt-get update && apt-get install -y python3 make g++ && rm -rf /var/lib/apt/lists/* COPY package*.json ./ @@ -27,11 +27,11 @@ FROM base AS builder RUN npm ci --omit=dev -FROM node:24-alpine AS 
runner +FROM node:24-slim AS runner WORKDIR /app -RUN apk add --no-cache curl tzdata +RUN apt-get update && apt-get install -y curl tzdata && rm -rf /var/lib/apt/lists/* COPY --from=builder /app/node_modules ./node_modules COPY --from=builder /app/package.json ./package.json diff --git a/install/Makefile b/install/Makefile index 53365f50..8a836b77 100644 --- a/install/Makefile +++ b/install/Makefile @@ -1,41 +1,24 @@ -all: update-versions go-build-release put-back -dev-all: dev-update-versions dev-build dev-clean +all: go-build-release + +# Build with version injection via ldflags +# Versions can be passed via: make go-build-release PANGOLIN_VERSION=x.x.x GERBIL_VERSION=x.x.x BADGER_VERSION=x.x.x +# Or fetched automatically if not provided (requires curl and jq) + +PANGOLIN_VERSION ?= $(shell curl -s https://api.github.com/repos/fosrl/pangolin/tags | jq -r '.[0].name') +GERBIL_VERSION ?= $(shell curl -s https://api.github.com/repos/fosrl/gerbil/tags | jq -r '.[0].name') +BADGER_VERSION ?= $(shell curl -s https://api.github.com/repos/fosrl/badger/tags | jq -r '.[0].name') + +LDFLAGS = -X main.pangolinVersion=$(PANGOLIN_VERSION) \ + -X main.gerbilVersion=$(GERBIL_VERSION) \ + -X main.badgerVersion=$(BADGER_VERSION) go-build-release: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/installer_linux_amd64 - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o bin/installer_linux_arm64 + @echo "Building with versions - Pangolin: $(PANGOLIN_VERSION), Gerbil: $(GERBIL_VERSION), Badger: $(BADGER_VERSION)" + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "$(LDFLAGS)" -o bin/installer_linux_amd64 + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "$(LDFLAGS)" -o bin/installer_linux_arm64 clean: rm -f bin/installer_linux_amd64 rm -f bin/installer_linux_arm64 -update-versions: - @echo "Fetching latest versions..." 
- cp main.go main.go.bak && \ - $(MAKE) dev-update-versions - -put-back: - mv main.go.bak main.go - -dev-update-versions: - if [ -z "$(tag)" ]; then \ - PANGOLIN_VERSION=$$(curl -s https://api.github.com/repos/fosrl/pangolin/tags | jq -r '.[0].name'); \ - else \ - PANGOLIN_VERSION=$(tag); \ - fi && \ - GERBIL_VERSION=$$(curl -s https://api.github.com/repos/fosrl/gerbil/tags | jq -r '.[0].name') && \ - BADGER_VERSION=$$(curl -s https://api.github.com/repos/fosrl/badger/tags | jq -r '.[0].name') && \ - echo "Latest versions - Pangolin: $$PANGOLIN_VERSION, Gerbil: $$GERBIL_VERSION, Badger: $$BADGER_VERSION" && \ - sed -i "s/config.PangolinVersion = \".*\"/config.PangolinVersion = \"$$PANGOLIN_VERSION\"/" main.go && \ - sed -i "s/config.GerbilVersion = \".*\"/config.GerbilVersion = \"$$GERBIL_VERSION\"/" main.go && \ - sed -i "s/config.BadgerVersion = \".*\"/config.BadgerVersion = \"$$BADGER_VERSION\"/" main.go && \ - echo "Updated main.go with latest versions" - -dev-build: go-build-release - -dev-clean: - @echo "Restoring version values ..." 
- sed -i "s/config.PangolinVersion = \".*\"/config.PangolinVersion = \"replaceme\"/" main.go && \ - sed -i "s/config.GerbilVersion = \".*\"/config.GerbilVersion = \"replaceme\"/" main.go && \ - sed -i "s/config.BadgerVersion = \".*\"/config.BadgerVersion = \"replaceme\"/" main.go - @echo "Restored version strings in main.go" +.PHONY: all go-build-release clean diff --git a/install/config.go b/install/config.go index e75dd50d..548e2ab3 100644 --- a/install/config.go +++ b/install/config.go @@ -118,19 +118,19 @@ func copyDockerService(sourceFile, destFile, serviceName string) error { } // Parse source Docker Compose YAML - var sourceCompose map[string]interface{} + var sourceCompose map[string]any if err := yaml.Unmarshal(sourceData, &sourceCompose); err != nil { return fmt.Errorf("error parsing source Docker Compose file: %w", err) } // Parse destination Docker Compose YAML - var destCompose map[string]interface{} + var destCompose map[string]any if err := yaml.Unmarshal(destData, &destCompose); err != nil { return fmt.Errorf("error parsing destination Docker Compose file: %w", err) } // Get services section from source - sourceServices, ok := sourceCompose["services"].(map[string]interface{}) + sourceServices, ok := sourceCompose["services"].(map[string]any) if !ok { return fmt.Errorf("services section not found in source file or has invalid format") } @@ -142,10 +142,10 @@ func copyDockerService(sourceFile, destFile, serviceName string) error { } // Get or create services section in destination - destServices, ok := destCompose["services"].(map[string]interface{}) + destServices, ok := destCompose["services"].(map[string]any) if !ok { // If services section doesn't exist, create it - destServices = make(map[string]interface{}) + destServices = make(map[string]any) destCompose["services"] = destServices } @@ -187,13 +187,12 @@ func backupConfig() error { return nil } -func MarshalYAMLWithIndent(data interface{}, indent int) ([]byte, error) { +func 
MarshalYAMLWithIndent(data any, indent int) ([]byte, error) { buffer := new(bytes.Buffer) encoder := yaml.NewEncoder(buffer) encoder.SetIndent(indent) - err := encoder.Encode(data) - if err != nil { + if err := encoder.Encode(data); err != nil { return nil, err } @@ -209,7 +208,7 @@ func replaceInFile(filepath, oldStr, newStr string) error { } // Replace the string - newContent := strings.Replace(string(content), oldStr, newStr, -1) + newContent := strings.ReplaceAll(string(content), oldStr, newStr) // Write the modified content back to the file err = os.WriteFile(filepath, []byte(newContent), 0644) @@ -228,28 +227,28 @@ func CheckAndAddTraefikLogVolume(composePath string) error { } // Parse YAML into a generic map - var compose map[string]interface{} + var compose map[string]any if err := yaml.Unmarshal(data, &compose); err != nil { return fmt.Errorf("error parsing compose file: %w", err) } // Get services section - services, ok := compose["services"].(map[string]interface{}) + services, ok := compose["services"].(map[string]any) if !ok { return fmt.Errorf("services section not found or invalid") } // Get traefik service - traefik, ok := services["traefik"].(map[string]interface{}) + traefik, ok := services["traefik"].(map[string]any) if !ok { return fmt.Errorf("traefik service not found or invalid") } // Check volumes logVolume := "./config/traefik/logs:/var/log/traefik" - var volumes []interface{} + var volumes []any - if existingVolumes, ok := traefik["volumes"].([]interface{}); ok { + if existingVolumes, ok := traefik["volumes"].([]any); ok { // Check if volume already exists for _, v := range existingVolumes { if v.(string) == logVolume { @@ -295,13 +294,13 @@ func MergeYAML(baseFile, overlayFile string) error { } // Parse base YAML into a map - var baseMap map[string]interface{} + var baseMap map[string]any if err := yaml.Unmarshal(baseContent, &baseMap); err != nil { return fmt.Errorf("error parsing base YAML: %v", err) } // Parse overlay YAML into a map - 
var overlayMap map[string]interface{} + var overlayMap map[string]any if err := yaml.Unmarshal(overlayContent, &overlayMap); err != nil { return fmt.Errorf("error parsing overlay YAML: %v", err) } @@ -324,8 +323,8 @@ func MergeYAML(baseFile, overlayFile string) error { } // mergeMap recursively merges two maps -func mergeMap(base, overlay map[string]interface{}) map[string]interface{} { - result := make(map[string]interface{}) +func mergeMap(base, overlay map[string]any) map[string]any { + result := make(map[string]any) // Copy all key-values from base map for k, v := range base { @@ -336,8 +335,8 @@ func mergeMap(base, overlay map[string]interface{}) map[string]interface{} { for k, v := range overlay { // If both maps have the same key and both values are maps, merge recursively if baseVal, ok := base[k]; ok { - if baseMap, isBaseMap := baseVal.(map[string]interface{}); isBaseMap { - if overlayMap, isOverlayMap := v.(map[string]interface{}); isOverlayMap { + if baseMap, isBaseMap := baseVal.(map[string]any); isBaseMap { + if overlayMap, isOverlayMap := v.(map[string]any); isOverlayMap { result[k] = mergeMap(baseMap, overlayMap) continue } diff --git a/install/containers.go b/install/containers.go index 333fd890..b5d18423 100644 --- a/install/containers.go +++ b/install/containers.go @@ -144,12 +144,13 @@ func installDocker() error { } func startDockerService() error { - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "linux": cmd := exec.Command("systemctl", "enable", "--now", "docker") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd.Run() - } else if runtime.GOOS == "darwin" { + case "darwin": // On macOS, Docker is usually started via the Docker Desktop application fmt.Println("Please start Docker Desktop manually on macOS.") return nil @@ -302,7 +303,7 @@ func pullContainers(containerType SupportedContainer) error { return nil } - return fmt.Errorf("Unsupported container type: %s", containerType) + return fmt.Errorf("unsupported container 
type: %s", containerType) } // startContainers starts the containers using the appropriate command. @@ -325,7 +326,7 @@ func startContainers(containerType SupportedContainer) error { return nil } - return fmt.Errorf("Unsupported container type: %s", containerType) + return fmt.Errorf("unsupported container type: %s", containerType) } // stopContainers stops the containers using the appropriate command. @@ -347,7 +348,7 @@ func stopContainers(containerType SupportedContainer) error { return nil } - return fmt.Errorf("Unsupported container type: %s", containerType) + return fmt.Errorf("unsupported container type: %s", containerType) } // restartContainer restarts a specific container using the appropriate command. @@ -369,5 +370,5 @@ func restartContainer(container string, containerType SupportedContainer) error return nil } - return fmt.Errorf("Unsupported container type: %s", containerType) + return fmt.Errorf("unsupported container type: %s", containerType) } diff --git a/install/crowdsec.go b/install/crowdsec.go index 401ef215..c75dccf3 100644 --- a/install/crowdsec.go +++ b/install/crowdsec.go @@ -27,9 +27,18 @@ func installCrowdsec(config Config) error { os.Exit(1) } - os.MkdirAll("config/crowdsec/db", 0755) - os.MkdirAll("config/crowdsec/acquis.d", 0755) - os.MkdirAll("config/traefik/logs", 0755) + if err := os.MkdirAll("config/crowdsec/db", 0755); err != nil { + fmt.Printf("Error creating config files: %v\n", err) + os.Exit(1) + } + if err := os.MkdirAll("config/crowdsec/acquis.d", 0755); err != nil { + fmt.Printf("Error creating config files: %v\n", err) + os.Exit(1) + } + if err := os.MkdirAll("config/traefik/logs", 0755); err != nil { + fmt.Printf("Error creating config files: %v\n", err) + os.Exit(1) + } if err := copyDockerService("config/crowdsec/docker-compose.yml", "docker-compose.yml", "crowdsec"); err != nil { fmt.Printf("Error copying docker service: %v\n", err) @@ -153,34 +162,34 @@ func CheckAndAddCrowdsecDependency(composePath string) error { } 
// Parse YAML into a generic map - var compose map[string]interface{} + var compose map[string]any if err := yaml.Unmarshal(data, &compose); err != nil { return fmt.Errorf("error parsing compose file: %w", err) } // Get services section - services, ok := compose["services"].(map[string]interface{}) + services, ok := compose["services"].(map[string]any) if !ok { return fmt.Errorf("services section not found or invalid") } // Get traefik service - traefik, ok := services["traefik"].(map[string]interface{}) + traefik, ok := services["traefik"].(map[string]any) if !ok { return fmt.Errorf("traefik service not found or invalid") } // Get dependencies - dependsOn, ok := traefik["depends_on"].(map[string]interface{}) + dependsOn, ok := traefik["depends_on"].(map[string]any) if ok { // Append the new block for crowdsec - dependsOn["crowdsec"] = map[string]interface{}{ + dependsOn["crowdsec"] = map[string]any{ "condition": "service_healthy", } } else { // No dependencies exist, create it - traefik["depends_on"] = map[string]interface{}{ - "crowdsec": map[string]interface{}{ + traefik["depends_on"] = map[string]any{ + "crowdsec": map[string]any{ "condition": "service_healthy", }, } diff --git a/install/main.go b/install/main.go index 7ea1a8e1..9de332b6 100644 --- a/install/main.go +++ b/install/main.go @@ -19,11 +19,17 @@ import ( "time" ) -// DO NOT EDIT THIS FUNCTION; IT MATCHED BY REGEX IN CICD +// Version variables injected at build time via -ldflags +var ( + pangolinVersion string + gerbilVersion string + badgerVersion string +) + func loadVersions(config *Config) { - config.PangolinVersion = "replaceme" - config.GerbilVersion = "replaceme" - config.BadgerVersion = "replaceme" + config.PangolinVersion = pangolinVersion + config.GerbilVersion = gerbilVersion + config.BadgerVersion = badgerVersion } //go:embed config/* @@ -99,7 +105,10 @@ func main() { os.Exit(1) } - moveFile("config/docker-compose.yml", "docker-compose.yml") + if err := 
moveFile("config/docker-compose.yml", "docker-compose.yml"); err != nil { + fmt.Printf("Error moving docker-compose.yml: %v\n", err) + os.Exit(1) + } fmt.Println("\nConfiguration files created successfully!") @@ -120,7 +129,11 @@ func main() { if !isDockerInstalled() && runtime.GOOS == "linux" && config.InstallationContainerType == Docker { if readBool("Docker is not installed. Would you like to install it?", true) { - installDocker() + if err := installDocker(); err != nil { + fmt.Printf("Error installing Docker: %v\n", err) + return + } + // try to start docker service but ignore errors if err := startDockerService(); err != nil { fmt.Println("Error starting Docker service:", err) @@ -129,7 +142,7 @@ func main() { } // wait 10 seconds for docker to start checking if docker is running every 2 seconds fmt.Println("Waiting for Docker to start...") - for i := 0; i < 5; i++ { + for range 5 { if isDockerRunning() { fmt.Println("Docker is running!") break @@ -287,7 +300,8 @@ func podmanOrDocker() SupportedContainer { os.Exit(1) } - if chosenContainer == Podman { + switch chosenContainer { + case Podman: if !isPodmanInstalled() { fmt.Println("Podman or podman-compose is not installed. Please install both manually. Automated installation will be available in a later release.") os.Exit(1) @@ -308,7 +322,7 @@ func podmanOrDocker() SupportedContainer { // Linux only. 
if err := run("bash", "-c", "echo 'net.ipv4.ip_unprivileged_port_start=80' > /etc/sysctl.d/99-podman.conf && sysctl --system"); err != nil { - fmt.Printf("Error configuring unprivileged ports: %v\n", err) + fmt.Printf("Error configuring unprivileged ports: %v\n", err) os.Exit(1) } } else { @@ -318,7 +332,7 @@ func podmanOrDocker() SupportedContainer { fmt.Println("Unprivileged ports have been configured.") } - } else if chosenContainer == Docker { + case Docker: // check if docker is not installed and the user is root if !isDockerInstalled() { if os.Geteuid() != 0 { @@ -333,7 +347,7 @@ func podmanOrDocker() SupportedContainer { fmt.Println("The installer will not be able to run docker commands without running it as root.") os.Exit(1) } - } else { + default: // This shouldn't happen unless there's a third container runtime. os.Exit(1) } @@ -402,10 +416,18 @@ func collectUserInput() Config { } func createConfigFiles(config Config) error { - os.MkdirAll("config", 0755) - os.MkdirAll("config/letsencrypt", 0755) - os.MkdirAll("config/db", 0755) - os.MkdirAll("config/logs", 0755) + if err := os.MkdirAll("config", 0755); err != nil { + return fmt.Errorf("failed to create config directory: %v", err) + } + if err := os.MkdirAll("config/letsencrypt", 0755); err != nil { + return fmt.Errorf("failed to create letsencrypt directory: %v", err) + } + if err := os.MkdirAll("config/db", 0755); err != nil { + return fmt.Errorf("failed to create db directory: %v", err) + } + if err := os.MkdirAll("config/logs", 0755); err != nil { + return fmt.Errorf("failed to create logs directory: %v", err) + } // Walk through all embedded files err := fs.WalkDir(configFiles, "config", func(path string, d fs.DirEntry, err error) error { @@ -559,22 +581,24 @@ func showSetupTokenInstructions(containerType SupportedContainer, dashboardDomai fmt.Println("To get your setup token, you need to:") fmt.Println("") fmt.Println("1. 
Start the containers") - if containerType == Docker { + switch containerType { + case Docker: fmt.Println(" docker compose up -d") - } else if containerType == Podman { + case Podman: fmt.Println(" podman-compose up -d") - } else { } + fmt.Println("") fmt.Println("2. Wait for the Pangolin container to start and generate the token") fmt.Println("") fmt.Println("3. Check the container logs for the setup token") - if containerType == Docker { + switch containerType { + case Docker: fmt.Println(" docker logs pangolin | grep -A 2 -B 2 'SETUP TOKEN'") - } else if containerType == Podman { + case Podman: fmt.Println(" podman logs pangolin | grep -A 2 -B 2 'SETUP TOKEN'") - } else { } + fmt.Println("") fmt.Println("4. Look for output like") fmt.Println(" === SETUP TOKEN GENERATED ===") @@ -636,10 +660,7 @@ func checkPortsAvailable(port int) error { addr := fmt.Sprintf(":%d", port) ln, err := net.Listen("tcp", addr) if err != nil { - return fmt.Errorf( - "ERROR: port %d is occupied or cannot be bound: %w\n\n", - port, err, - ) + return fmt.Errorf("ERROR: port %d is occupied or cannot be bound: %w", port, err) } if closeErr := ln.Close(); closeErr != nil { fmt.Fprintf(os.Stderr, diff --git a/messages/en-US.json b/messages/en-US.json index d872d8e3..2dfa496f 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -649,7 +649,8 @@ "resourcesUsersRolesAccess": "User and role-based access control", "resourcesErrorUpdate": "Failed to toggle resource", "resourcesErrorUpdateDescription": "An error occurred while updating the resource", - "access": "Access Control", + "access": "Access", + "accessControl": "Access Control", "shareLink": "{resource} Share Link", "resourceSelect": "Select resource", "shareLinks": "Share Links", diff --git a/server/auth/sessions/resource.ts b/server/auth/sessions/resource.ts index 3b9da3d7..a1ae1337 100644 --- a/server/auth/sessions/resource.ts +++ b/server/auth/sessions/resource.ts @@ -87,7 +87,7 @@ export async function 
validateResourceSessionToken( if (Date.now() >= resourceSession.expiresAt) { await db .delete(resourceSessions) - .where(eq(resourceSessions.sessionId, resourceSessions.sessionId)); + .where(eq(resourceSessions.sessionId, sessionId)); return { resourceSession: null }; } else if ( Date.now() >= @@ -181,7 +181,7 @@ export function serializeResourceSessionCookie( return `${cookieName}_s.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Secure; Domain=${domain}`; } else { if (expiresAt === undefined) { - return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Domain=$domain}`; + return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Path=/; Domain=${domain}`; } return `${cookieName}.${now}=${token}; HttpOnly; SameSite=Lax; Expires=${expiresAt.toUTCString()}; Path=/; Domain=${domain}`; } diff --git a/server/db/pg/index.ts b/server/db/pg/index.ts index 86e31802..f8c04ac9 100644 --- a/server/db/pg/index.ts +++ b/server/db/pg/index.ts @@ -1,4 +1,5 @@ export * from "./driver"; +export * from "./logsDriver"; export * from "./safeRead"; export * from "./schema/schema"; export * from "./schema/privateSchema"; diff --git a/server/db/pg/logsDriver.ts b/server/db/pg/logsDriver.ts new file mode 100644 index 00000000..3e4c2b7e --- /dev/null +++ b/server/db/pg/logsDriver.ts @@ -0,0 +1,89 @@ +import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres"; +import { Pool } from "pg"; +import { readConfigFile } from "@server/lib/readConfigFile"; +import { readPrivateConfigFile } from "@server/private/lib/readConfigFile"; +import { withReplicas } from "drizzle-orm/pg-core"; +import { build } from "@server/build"; +import { db as mainDb, primaryDb as mainPrimaryDb } from "./driver"; + +function createLogsDb() { + // Only use separate logs database in SaaS builds + if (build !== "saas") { + return mainDb; + } + + const config = readConfigFile(); + const privateConfig = readPrivateConfigFile(); + + // Merge configs, 
prioritizing private config + const logsConfig = privateConfig.postgres_logs || config.postgres_logs; + + // Check environment variable first + let connectionString = process.env.POSTGRES_LOGS_CONNECTION_STRING; + let replicaConnections: Array<{ connection_string: string }> = []; + + if (!connectionString && logsConfig) { + connectionString = logsConfig.connection_string; + replicaConnections = logsConfig.replicas || []; + } + + // If POSTGRES_LOGS_REPLICA_CONNECTION_STRINGS is set, use it + if (process.env.POSTGRES_LOGS_REPLICA_CONNECTION_STRINGS) { + replicaConnections = + process.env.POSTGRES_LOGS_REPLICA_CONNECTION_STRINGS.split(",").map( + (conn) => ({ + connection_string: conn.trim() + }) + ); + } + + // If no logs database is configured, fall back to main database + if (!connectionString) { + return mainDb; + } + + // Create separate connection pool for logs database + const poolConfig = logsConfig?.pool || config.postgres?.pool; + const primaryPool = new Pool({ + connectionString, + max: poolConfig?.max_connections || 20, + idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000, + connectionTimeoutMillis: poolConfig?.connection_timeout_ms || 5000 + }); + + const replicas = []; + + if (!replicaConnections.length) { + replicas.push( + DrizzlePostgres(primaryPool, { + logger: process.env.QUERY_LOGGING == "true" + }) + ); + } else { + for (const conn of replicaConnections) { + const replicaPool = new Pool({ + connectionString: conn.connection_string, + max: poolConfig?.max_replica_connections || 20, + idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000, + connectionTimeoutMillis: + poolConfig?.connection_timeout_ms || 5000 + }); + replicas.push( + DrizzlePostgres(replicaPool, { + logger: process.env.QUERY_LOGGING == "true" + }) + ); + } + } + + return withReplicas( + DrizzlePostgres(primaryPool, { + logger: process.env.QUERY_LOGGING == "true" + }), + replicas as any + ); +} + +export const logsDb = createLogsDb(); +export default logsDb; +export const 
primaryLogsDb = logsDb.$primary; \ No newline at end of file diff --git a/server/db/sqlite/index.ts b/server/db/sqlite/index.ts index 86e31802..f8c04ac9 100644 --- a/server/db/sqlite/index.ts +++ b/server/db/sqlite/index.ts @@ -1,4 +1,5 @@ export * from "./driver"; +export * from "./logsDriver"; export * from "./safeRead"; export * from "./schema/schema"; export * from "./schema/privateSchema"; diff --git a/server/db/sqlite/logsDriver.ts b/server/db/sqlite/logsDriver.ts new file mode 100644 index 00000000..f70c79fc --- /dev/null +++ b/server/db/sqlite/logsDriver.ts @@ -0,0 +1,7 @@ +import { db as mainDb } from "./driver"; + +// SQLite doesn't support separate databases for logs in the same way as Postgres +// Always use the main database connection for SQLite +export const logsDb = mainDb; +export default logsDb; +export const primaryLogsDb = logsDb; \ No newline at end of file diff --git a/server/internalServer.ts b/server/internalServer.ts index d15e3c45..7ba046e4 100644 --- a/server/internalServer.ts +++ b/server/internalServer.ts @@ -16,6 +16,11 @@ const internalPort = config.getRawConfig().server.internal_port; export function createInternalServer() { const internalServer = express(); + const trustProxy = config.getRawConfig().server.trust_proxy; + if (trustProxy) { + internalServer.set("trust proxy", trustProxy); + } + internalServer.use(helmet()); internalServer.use(cors()); internalServer.use(stripDuplicateSesions); diff --git a/server/lib/billing/usageService.ts b/server/lib/billing/usageService.ts index d7299284..74241a4c 100644 --- a/server/lib/billing/usageService.ts +++ b/server/lib/billing/usageService.ts @@ -230,7 +230,7 @@ export class UsageService { const orgIdToUse = await this.getBillingOrg(orgId); const cacheKey = `customer_${orgIdToUse}_${featureId}`; - const cached = cache.get(cacheKey); + const cached = await cache.get(cacheKey); if (cached) { return cached; @@ -253,7 +253,7 @@ export class UsageService { const customerId = 
customer.customerId; // Cache the result - cache.set(cacheKey, customerId, 300); // 5 minute TTL + await cache.set(cacheKey, customerId, 300); // 5 minute TTL return customerId; } catch (error) { diff --git a/server/lib/blueprints/clientResources.ts b/server/lib/blueprints/clientResources.ts index 64de9867..80c691c6 100644 --- a/server/lib/blueprints/clientResources.ts +++ b/server/lib/blueprints/clientResources.ts @@ -11,7 +11,7 @@ import { userSiteResources } from "@server/db"; import { sites } from "@server/db"; -import { eq, and, ne, inArray } from "drizzle-orm"; +import { eq, and, ne, inArray, or } from "drizzle-orm"; import { Config } from "./types"; import logger from "@server/logger"; import { getNextAvailableAliasAddress } from "../ip"; @@ -142,7 +142,10 @@ export async function updateClientResources( .innerJoin(userOrgs, eq(users.userId, userOrgs.userId)) .where( and( - inArray(users.username, resourceData.users), + or( + inArray(users.username, resourceData.users), + inArray(users.email, resourceData.users) + ), eq(userOrgs.orgId, orgId) ) ); @@ -276,7 +279,10 @@ export async function updateClientResources( .innerJoin(userOrgs, eq(users.userId, userOrgs.userId)) .where( and( - inArray(users.username, resourceData.users), + or( + inArray(users.username, resourceData.users), + inArray(users.email, resourceData.users) + ), eq(userOrgs.orgId, orgId) ) ); diff --git a/server/lib/blueprints/proxyResources.ts b/server/lib/blueprints/proxyResources.ts index 55a7712b..2696b68c 100644 --- a/server/lib/blueprints/proxyResources.ts +++ b/server/lib/blueprints/proxyResources.ts @@ -212,7 +212,10 @@ export async function updateProxyResources( } else { // Update existing resource - const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.maintencePage); + const isLicensed = await isLicensedOrSubscribed( + orgId, + tierMatrix.maintencePage + ); if (!isLicensed) { resourceData.maintenance = undefined; } @@ -590,7 +593,10 @@ export async function 
updateProxyResources( existingRule.action !== getRuleAction(rule.action) || existingRule.match !== rule.match.toUpperCase() || existingRule.value !== - getRuleValue(rule.match.toUpperCase(), rule.value) || + getRuleValue( + rule.match.toUpperCase(), + rule.value + ) || existingRule.priority !== intendedPriority ) { validateRule(rule); @@ -648,7 +654,10 @@ export async function updateProxyResources( ); } - const isLicensed = await isLicensedOrSubscribed(orgId, tierMatrix.maintencePage); + const isLicensed = await isLicensedOrSubscribed( + orgId, + tierMatrix.maintencePage + ); if (!isLicensed) { resourceData.maintenance = undefined; } @@ -935,7 +944,12 @@ async function syncUserResources( .select() .from(users) .innerJoin(userOrgs, eq(users.userId, userOrgs.userId)) - .where(and(eq(users.username, username), eq(userOrgs.orgId, orgId))) + .where( + and( + or(eq(users.username, username), eq(users.email, username)), + eq(userOrgs.orgId, orgId) + ) + ) .limit(1); if (!user) { diff --git a/server/lib/blueprints/types.ts b/server/lib/blueprints/types.ts index edf4b0c7..2239e4f9 100644 --- a/server/lib/blueprints/types.ts +++ b/server/lib/blueprints/types.ts @@ -69,7 +69,7 @@ export const AuthSchema = z.object({ .refine((roles) => !roles.includes("Admin"), { error: "Admin role cannot be included in sso-roles" }), - "sso-users": z.array(z.email()).optional().default([]), + "sso-users": z.array(z.string()).optional().default([]), "whitelist-users": z.array(z.email()).optional().default([]), "auto-login-idp": z.int().positive().optional() }); @@ -335,7 +335,7 @@ export const ClientResourceSchema = z .refine((roles) => !roles.includes("Admin"), { error: "Admin role cannot be included in roles" }), - users: z.array(z.email()).optional().default([]), + users: z.array(z.string()).optional().default([]), machines: z.array(z.string()).optional().default([]) }) .refine( diff --git a/server/lib/cache.ts b/server/lib/cache.ts index 4910d945..1d8c2453 100644 --- a/server/lib/cache.ts 
+++ b/server/lib/cache.ts @@ -1,9 +1,10 @@ import NodeCache from "node-cache"; import logger from "@server/logger"; +import { redisManager } from "@server/private/lib/redis"; -// Create cache with maxKeys limit to prevent memory leaks +// Create local cache with maxKeys limit to prevent memory leaks // With ~10k requests/day and 5min TTL, 10k keys should be more than sufficient -export const cache = new NodeCache({ +export const localCache = new NodeCache({ stdTTL: 3600, checkperiod: 120, maxKeys: 10000 @@ -11,10 +12,255 @@ export const cache = new NodeCache({ // Log cache statistics periodically for monitoring setInterval(() => { - const stats = cache.getStats(); + const stats = localCache.getStats(); logger.debug( - `Cache stats - Keys: ${stats.keys}, Hits: ${stats.hits}, Misses: ${stats.misses}, Hit rate: ${stats.hits > 0 ? ((stats.hits / (stats.hits + stats.misses)) * 100).toFixed(2) : 0}%` + `Local cache stats - Keys: ${stats.keys}, Hits: ${stats.hits}, Misses: ${stats.misses}, Hit rate: ${stats.hits > 0 ? ((stats.hits / (stats.hits + stats.misses)) * 100).toFixed(2) : 0}%` ); }, 300000); // Every 5 minutes +/** + * Adaptive cache that uses Redis when available in multi-node environments, + * otherwise falls back to local memory cache for single-node deployments. + */ +class AdaptiveCache { + private useRedis(): boolean { + return redisManager.isRedisEnabled() && redisManager.getHealthStatus().isHealthy; + } + + /** + * Set a value in the cache + * @param key - Cache key + * @param value - Value to cache (will be JSON stringified for Redis) + * @param ttl - Time to live in seconds (0 = no expiration) + * @returns boolean indicating success + */ + async set(key: string, value: any, ttl?: number): Promise<boolean> { + const effectiveTtl = ttl === 0 ? 
undefined : ttl; + + if (this.useRedis()) { + try { + const serialized = JSON.stringify(value); + const success = await redisManager.set(key, serialized, effectiveTtl); + + if (success) { + logger.debug(`Set key in Redis: ${key}`); + return true; + } + + // Redis failed, fall through to local cache + logger.debug(`Redis set failed for key ${key}, falling back to local cache`); + } catch (error) { + logger.error(`Redis set error for key ${key}:`, error); + // Fall through to local cache + } + } + + // Use local cache as fallback or primary + const success = localCache.set(key, value, effectiveTtl || 0); + if (success) { + logger.debug(`Set key in local cache: ${key}`); + } + return success; + } + + /** + * Get a value from the cache + * @param key - Cache key + * @returns The cached value or undefined if not found + */ + async get<T>(key: string): Promise<T | undefined> { + if (this.useRedis()) { + try { + const value = await redisManager.get(key); + + if (value !== null) { + logger.debug(`Cache hit in Redis: ${key}`); + return JSON.parse(value) as T; + } + + logger.debug(`Cache miss in Redis: ${key}`); + return undefined; + } catch (error) { + logger.error(`Redis get error for key ${key}:`, error); + // Fall through to local cache + } + } + + // Use local cache as fallback or primary + const value = localCache.get<T>(key); + if (value !== undefined) { + logger.debug(`Cache hit in local cache: ${key}`); + } else { + logger.debug(`Cache miss in local cache: ${key}`); + } + return value; + } + + /** + * Delete a value from the cache + * @param key - Cache key or array of keys + * @returns Number of deleted entries + */ + async del(key: string | string[]): Promise<number> { + const keys = Array.isArray(key) ? 
key : [key]; + let deletedCount = 0; + + if (this.useRedis()) { + try { + for (const k of keys) { + const success = await redisManager.del(k); + if (success) { + deletedCount++; + logger.debug(`Deleted key from Redis: ${k}`); + } + } + + if (deletedCount === keys.length) { + return deletedCount; + } + + // Some Redis deletes failed, fall through to local cache + logger.debug(`Some Redis deletes failed, falling back to local cache`); + } catch (error) { + logger.error(`Redis del error for keys ${keys.join(", ")}:`, error); + // Fall through to local cache + deletedCount = 0; + } + } + + // Use local cache as fallback or primary + for (const k of keys) { + const success = localCache.del(k); + if (success > 0) { + deletedCount++; + logger.debug(`Deleted key from local cache: ${k}`); + } + } + + return deletedCount; + } + + /** + * Check if a key exists in the cache + * @param key - Cache key + * @returns boolean indicating if key exists + */ + async has(key: string): Promise<boolean> { + if (this.useRedis()) { + try { + const value = await redisManager.get(key); + return value !== null; + } catch (error) { + logger.error(`Redis has error for key ${key}:`, error); + // Fall through to local cache + } + } + + // Use local cache as fallback or primary + return localCache.has(key); + } + + /** + * Get multiple values from the cache + * @param keys - Array of cache keys + * @returns Array of values (undefined for missing keys) + */ + async mget<T>(keys: string[]): Promise<(T | undefined)[]> { + if (this.useRedis()) { + try { + const results: (T | undefined)[] = []; + + for (const key of keys) { + const value = await redisManager.get(key); + if (value !== null) { + results.push(JSON.parse(value) as T); + } else { + results.push(undefined); + } + } + + return results; + } catch (error) { + logger.error(`Redis mget error:`, error); + // Fall through to local cache + } + } + + // Use local cache as fallback or primary + return keys.map((key) => localCache.get<T>(key)); + } + + /** + * Flush 
all keys from the cache + */ + async flushAll(): Promise<void> { + if (this.useRedis()) { + logger.warn("Adaptive cache flushAll called - Redis flush not implemented, only local cache will be flushed"); + } + + localCache.flushAll(); + logger.debug("Flushed local cache"); + } + + /** + * Get cache statistics + * Note: Only returns local cache stats, Redis stats are not included + */ + getStats() { + return localCache.getStats(); + } + + /** + * Get the current cache backend being used + * @returns "redis" if Redis is available and healthy, "local" otherwise + */ + getCurrentBackend(): "redis" | "local" { + return this.useRedis() ? "redis" : "local"; + } + + /** + * Take a key from the cache and delete it + * @param key - Cache key + * @returns The value or undefined if not found + */ + async take<T>(key: string): Promise<T | undefined> { + const value = await this.get<T>(key); + if (value !== undefined) { + await this.del(key); + } + return value; + } + + /** + * Get TTL (time to live) for a key + * @param key - Cache key + * @returns TTL in seconds, 0 if no expiration, -1 if key doesn't exist + */ + getTtl(key: string): number { + // Note: This only works for local cache, Redis TTL is not supported + if (this.useRedis()) { + logger.warn(`getTtl called for key ${key} but Redis TTL lookup is not implemented`); + } + + const ttl = localCache.getTtl(key); + if (ttl === undefined) { + return -1; + } + return Math.max(0, Math.floor((ttl - Date.now()) / 1000)); + } + + /** + * Get all keys from the cache + * Note: Only returns local cache keys, Redis keys are not included + */ + keys(): string[] { + if (this.useRedis()) { + logger.warn("keys() called but Redis keys are not included, only local cache keys returned"); + } + return localCache.keys(); + } +} + +// Export singleton instance +export const cache = new AdaptiveCache(); export default cache; diff --git a/server/lib/readConfigFile.ts b/server/lib/readConfigFile.ts index bfca5970..cca0aa6a 100644 --- a/server/lib/readConfigFile.ts +++ 
b/server/lib/readConfigFile.ts @@ -189,6 +189,46 @@ export const configSchema = z .prefault({}) }) .optional(), + postgres_logs: z + .object({ + connection_string: z + .string() + .optional() + .transform(getEnvOrYaml("POSTGRES_LOGS_CONNECTION_STRING")), + replicas: z + .array( + z.object({ + connection_string: z.string() + }) + ) + .optional(), + pool: z + .object({ + max_connections: z + .number() + .positive() + .optional() + .default(20), + max_replica_connections: z + .number() + .positive() + .optional() + .default(10), + idle_timeout_ms: z + .number() + .positive() + .optional() + .default(30000), + connection_timeout_ms: z + .number() + .positive() + .optional() + .default(5000) + }) + .optional() + .prefault({}) + }) + .optional(), traefik: z .object({ http_entrypoint: z.string().optional().default("web"), diff --git a/server/private/lib/certificates.ts b/server/private/lib/certificates.ts index bc1dffcd..c113ddd9 100644 --- a/server/private/lib/certificates.ts +++ b/server/private/lib/certificates.ts @@ -55,7 +55,7 @@ export async function getValidCertificatesForDomains( if (useCache) { for (const domain of domains) { const cacheKey = `cert:${domain}`; - const cachedCert = cache.get(cacheKey); + const cachedCert = await cache.get(cacheKey); if (cachedCert) { finalResults.push(cachedCert); // Valid cache hit } else { @@ -169,7 +169,7 @@ export async function getValidCertificatesForDomains( // Add to cache for future requests, using the *requested domain* as the key if (useCache) { const cacheKey = `cert:${domain}`; - cache.set(cacheKey, resultCert, 180); + await cache.set(cacheKey, resultCert, 180); } } } diff --git a/server/private/lib/logAccessAudit.ts b/server/private/lib/logAccessAudit.ts index 33dcaf1f..88e553ad 100644 --- a/server/private/lib/logAccessAudit.ts +++ b/server/private/lib/logAccessAudit.ts @@ -11,7 +11,7 @@ * This file is not licensed under the AGPLv3. 
*/ -import { accessAuditLog, db, orgs } from "@server/db"; +import { accessAuditLog, logsDb, db, orgs } from "@server/db"; import { getCountryCodeForIp } from "@server/lib/geoip"; import logger from "@server/logger"; import { and, eq, lt } from "drizzle-orm"; @@ -21,7 +21,7 @@ import { stripPortFromHost } from "@server/lib/ip"; async function getAccessDays(orgId: string): Promise { // check cache first - const cached = cache.get(`org_${orgId}_accessDays`); + const cached = await cache.get(`org_${orgId}_accessDays`); if (cached !== undefined) { return cached; } @@ -39,7 +39,7 @@ async function getAccessDays(orgId: string): Promise { } // store the result in cache - cache.set( + await cache.set( `org_${orgId}_accessDays`, org.settingsLogRetentionDaysAction, 300 @@ -52,7 +52,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) { const cutoffTimestamp = calculateCutoffTimestamp(retentionDays); try { - await db + await logsDb .delete(accessAuditLog) .where( and( @@ -124,7 +124,7 @@ export async function logAccessAudit(data: { ? 
await getCountryCodeFromIp(data.requestIp) : undefined; - await db.insert(accessAuditLog).values({ + await logsDb.insert(accessAuditLog).values({ timestamp: timestamp, orgId: data.orgId, actorType, @@ -146,14 +146,14 @@ export async function logAccessAudit(data: { async function getCountryCodeFromIp(ip: string): Promise { const geoIpCacheKey = `geoip_access:${ip}`; - let cachedCountryCode: string | undefined = cache.get(geoIpCacheKey); + let cachedCountryCode: string | undefined = await cache.get(geoIpCacheKey); if (!cachedCountryCode) { cachedCountryCode = await getCountryCodeForIp(ip); // do it locally // Only cache successful lookups to avoid filling cache with undefined values if (cachedCountryCode) { // Cache for longer since IP geolocation doesn't change frequently - cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes + await cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes } } diff --git a/server/private/lib/readConfigFile.ts b/server/private/lib/readConfigFile.ts index a9de84e8..a19a1b65 100644 --- a/server/private/lib/readConfigFile.ts +++ b/server/private/lib/readConfigFile.ts @@ -83,6 +83,46 @@ export const privateConfigSchema = z.object({ .optional() }) .optional(), + postgres_logs: z + .object({ + connection_string: z + .string() + .optional() + .transform(getEnvOrYaml("POSTGRES_LOGS_CONNECTION_STRING")), + replicas: z + .array( + z.object({ + connection_string: z.string() + }) + ) + .optional(), + pool: z + .object({ + max_connections: z + .number() + .positive() + .optional() + .default(20), + max_replica_connections: z + .number() + .positive() + .optional() + .default(10), + idle_timeout_ms: z + .number() + .positive() + .optional() + .default(30000), + connection_timeout_ms: z + .number() + .positive() + .optional() + .default(5000) + }) + .optional() + .prefault({}) + }) + .optional(), gerbil: z .object({ local_exit_node_reachable_at: z diff --git a/server/private/middlewares/logActionAudit.ts 
b/server/private/middlewares/logActionAudit.ts index 17cc67c0..d0474dc3 100644 --- a/server/private/middlewares/logActionAudit.ts +++ b/server/private/middlewares/logActionAudit.ts @@ -12,7 +12,7 @@ */ import { ActionsEnum } from "@server/auth/actions"; -import { actionAuditLog, db, orgs } from "@server/db"; +import { actionAuditLog, logsDb, db, orgs } from "@server/db"; import logger from "@server/logger"; import HttpCode from "@server/types/HttpCode"; import { Request, Response, NextFunction } from "express"; @@ -23,7 +23,7 @@ import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs"; async function getActionDays(orgId: string): Promise { // check cache first - const cached = cache.get(`org_${orgId}_actionDays`); + const cached = await cache.get(`org_${orgId}_actionDays`); if (cached !== undefined) { return cached; } @@ -41,7 +41,7 @@ async function getActionDays(orgId: string): Promise { } // store the result in cache - cache.set( + await cache.set( `org_${orgId}_actionDays`, org.settingsLogRetentionDaysAction, 300 @@ -54,7 +54,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) { const cutoffTimestamp = calculateCutoffTimestamp(retentionDays); try { - await db + await logsDb .delete(actionAuditLog) .where( and( @@ -123,7 +123,7 @@ export function logActionAudit(action: ActionsEnum) { metadata = JSON.stringify(req.params); } - await db.insert(actionAuditLog).values({ + await logsDb.insert(actionAuditLog).values({ timestamp, orgId, actorType, diff --git a/server/private/routers/auditLogs/queryAccessAuditLog.ts b/server/private/routers/auditLogs/queryAccessAuditLog.ts index 96d241fb..7830dd9d 100644 --- a/server/private/routers/auditLogs/queryAccessAuditLog.ts +++ b/server/private/routers/auditLogs/queryAccessAuditLog.ts @@ -11,11 +11,11 @@ * This file is not licensed under the AGPLv3. 
*/ -import { accessAuditLog, db, resources } from "@server/db"; +import { accessAuditLog, logsDb, resources, db, primaryDb } from "@server/db"; import { registry } from "@server/openApi"; import { NextFunction } from "express"; import { Request, Response } from "express"; -import { eq, gt, lt, and, count, desc } from "drizzle-orm"; +import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm"; import { OpenAPITags } from "@server/openApi"; import { z } from "zod"; import createHttpError from "http-errors"; @@ -115,15 +115,13 @@ function getWhere(data: Q) { } export function queryAccess(data: Q) { - return db + return logsDb .select({ orgId: accessAuditLog.orgId, action: accessAuditLog.action, actorType: accessAuditLog.actorType, actorId: accessAuditLog.actorId, resourceId: accessAuditLog.resourceId, - resourceName: resources.name, - resourceNiceId: resources.niceId, ip: accessAuditLog.ip, location: accessAuditLog.location, userAgent: accessAuditLog.userAgent, @@ -133,16 +131,46 @@ export function queryAccess(data: Q) { actor: accessAuditLog.actor }) .from(accessAuditLog) - .leftJoin( - resources, - eq(accessAuditLog.resourceId, resources.resourceId) - ) .where(getWhere(data)) .orderBy(desc(accessAuditLog.timestamp), desc(accessAuditLog.id)); } +async function enrichWithResourceDetails(logs: Awaited>) { + // If logs database is the same as main database, we can do a join + // Otherwise, we need to fetch resource details separately + const resourceIds = logs + .map(log => log.resourceId) + .filter((id): id is number => id !== null && id !== undefined); + + if (resourceIds.length === 0) { + return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null })); + } + + // Fetch resource details from main database + const resourceDetails = await primaryDb + .select({ + resourceId: resources.resourceId, + name: resources.name, + niceId: resources.niceId + }) + .from(resources) + .where(inArray(resources.resourceId, resourceIds)); + + // Create a map for 
quick lookup + const resourceMap = new Map( + resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }]) + ); + + // Enrich logs with resource details + return logs.map(log => ({ + ...log, + resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null, + resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null + })); +} + export function countAccessQuery(data: Q) { - const countQuery = db + const countQuery = logsDb .select({ count: count() }) .from(accessAuditLog) .where(getWhere(data)); @@ -161,7 +189,7 @@ async function queryUniqueFilterAttributes( ); // Get unique actors - const uniqueActors = await db + const uniqueActors = await logsDb .selectDistinct({ actor: accessAuditLog.actor }) @@ -169,7 +197,7 @@ async function queryUniqueFilterAttributes( .where(baseConditions); // Get unique locations - const uniqueLocations = await db + const uniqueLocations = await logsDb .selectDistinct({ locations: accessAuditLog.location }) @@ -177,25 +205,40 @@ async function queryUniqueFilterAttributes( .where(baseConditions); // Get unique resources with names - const uniqueResources = await db + const uniqueResources = await logsDb .selectDistinct({ - id: accessAuditLog.resourceId, - name: resources.name + id: accessAuditLog.resourceId }) .from(accessAuditLog) - .leftJoin( - resources, - eq(accessAuditLog.resourceId, resources.resourceId) - ) .where(baseConditions); + // Fetch resource names from main database for the unique resource IDs + const resourceIds = uniqueResources + .map(row => row.id) + .filter((id): id is number => id !== null); + + let resourcesWithNames: Array<{ id: number; name: string | null }> = []; + + if (resourceIds.length > 0) { + const resourceDetails = await primaryDb + .select({ + resourceId: resources.resourceId, + name: resources.name + }) + .from(resources) + .where(inArray(resources.resourceId, resourceIds)); + + resourcesWithNames = resourceDetails.map(r => ({ + id: 
r.resourceId, + name: r.name + })); + } + return { actors: uniqueActors .map((row) => row.actor) .filter((actor): actor is string => actor !== null), - resources: uniqueResources.filter( - (row): row is { id: number; name: string | null } => row.id !== null - ), + resources: resourcesWithNames, locations: uniqueLocations .map((row) => row.locations) .filter((location): location is string => location !== null) @@ -243,7 +286,10 @@ export async function queryAccessAuditLogs( const baseQuery = queryAccess(data); - const log = await baseQuery.limit(data.limit).offset(data.offset); + const logsRaw = await baseQuery.limit(data.limit).offset(data.offset); + + // Enrich with resource details (handles cross-database scenario) + const log = await enrichWithResourceDetails(logsRaw); const totalCountResult = await countAccessQuery(data); const totalCount = totalCountResult[0].count; diff --git a/server/private/routers/auditLogs/queryActionAuditLog.ts b/server/private/routers/auditLogs/queryActionAuditLog.ts index 7eed741b..bd636dee 100644 --- a/server/private/routers/auditLogs/queryActionAuditLog.ts +++ b/server/private/routers/auditLogs/queryActionAuditLog.ts @@ -11,7 +11,7 @@ * This file is not licensed under the AGPLv3. 
*/ -import { actionAuditLog, db } from "@server/db"; +import { actionAuditLog, logsDb } from "@server/db"; import { registry } from "@server/openApi"; import { NextFunction } from "express"; import { Request, Response } from "express"; @@ -97,7 +97,7 @@ function getWhere(data: Q) { } export function queryAction(data: Q) { - return db + return logsDb .select({ orgId: actionAuditLog.orgId, action: actionAuditLog.action, @@ -113,7 +113,7 @@ export function queryAction(data: Q) { } export function countActionQuery(data: Q) { - const countQuery = db + const countQuery = logsDb .select({ count: count() }) .from(actionAuditLog) .where(getWhere(data)); @@ -132,14 +132,14 @@ async function queryUniqueFilterAttributes( ); // Get unique actors - const uniqueActors = await db + const uniqueActors = await logsDb .selectDistinct({ actor: actionAuditLog.actor }) .from(actionAuditLog) .where(baseConditions); - const uniqueActions = await db + const uniqueActions = await logsDb .selectDistinct({ action: actionAuditLog.action }) diff --git a/server/private/routers/external.ts b/server/private/routers/external.ts index a1352342..bd4d232d 100644 --- a/server/private/routers/external.ts +++ b/server/private/routers/external.ts @@ -480,9 +480,9 @@ authenticated.get( authenticated.post( "/re-key/:clientId/regenerate-client-secret", + verifyClientAccess, // this is first to set the org id verifyValidLicense, verifyValidSubscription(tierMatrix.rotateCredentials), - verifyClientAccess, // this is first to set the org id verifyLimits, verifyUserHasAction(ActionsEnum.reGenerateSecret), reKey.reGenerateClientSecret @@ -490,9 +490,9 @@ authenticated.post( authenticated.post( "/re-key/:siteId/regenerate-site-secret", + verifySiteAccess, // this is first to set the org id verifyValidLicense, verifyValidSubscription(tierMatrix.rotateCredentials), - verifySiteAccess, // this is first to set the org id verifyLimits, verifyUserHasAction(ActionsEnum.reGenerateSecret), reKey.reGenerateSiteSecret diff 
--git a/server/routers/auditLogs/queryRequestAnalytics.ts b/server/routers/auditLogs/queryRequestAnalytics.ts index a6f9cb76..e838c5f5 100644 --- a/server/routers/auditLogs/queryRequestAnalytics.ts +++ b/server/routers/auditLogs/queryRequestAnalytics.ts @@ -1,4 +1,4 @@ -import { db, requestAuditLog, driver, primaryDb } from "@server/db"; +import { logsDb, requestAuditLog, driver, primaryLogsDb } from "@server/db"; import { registry } from "@server/openApi"; import { NextFunction } from "express"; import { Request, Response } from "express"; @@ -74,12 +74,12 @@ async function query(query: Q) { ); } - const [all] = await primaryDb + const [all] = await primaryLogsDb .select({ total: count() }) .from(requestAuditLog) .where(baseConditions); - const [blocked] = await primaryDb + const [blocked] = await primaryLogsDb .select({ total: count() }) .from(requestAuditLog) .where(and(baseConditions, eq(requestAuditLog.action, false))); @@ -90,7 +90,7 @@ async function query(query: Q) { const DISTINCT_LIMIT = 500; - const requestsPerCountry = await primaryDb + const requestsPerCountry = await primaryLogsDb .selectDistinct({ code: requestAuditLog.location, count: totalQ @@ -118,7 +118,7 @@ async function query(query: Q) { const booleanTrue = driver === "pg" ? sql`true` : sql`1`; const booleanFalse = driver === "pg" ? 
sql`false` : sql`0`; - const requestsPerDay = await primaryDb + const requestsPerDay = await primaryLogsDb .select({ day: groupByDayFunction.as("day"), allowedCount: diff --git a/server/routers/auditLogs/queryRequestAuditLog.ts b/server/routers/auditLogs/queryRequestAuditLog.ts index 98c23721..3b598e03 100644 --- a/server/routers/auditLogs/queryRequestAuditLog.ts +++ b/server/routers/auditLogs/queryRequestAuditLog.ts @@ -1,8 +1,8 @@ -import { db, primaryDb, requestAuditLog, resources } from "@server/db"; +import { logsDb, primaryLogsDb, requestAuditLog, resources, db, primaryDb } from "@server/db"; import { registry } from "@server/openApi"; import { NextFunction } from "express"; import { Request, Response } from "express"; -import { eq, gt, lt, and, count, desc } from "drizzle-orm"; +import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm"; import { OpenAPITags } from "@server/openApi"; import { z } from "zod"; import createHttpError from "http-errors"; @@ -107,7 +107,7 @@ function getWhere(data: Q) { } export function queryRequest(data: Q) { - return primaryDb + return primaryLogsDb .select({ id: requestAuditLog.id, timestamp: requestAuditLog.timestamp, @@ -129,21 +129,49 @@ export function queryRequest(data: Q) { host: requestAuditLog.host, path: requestAuditLog.path, method: requestAuditLog.method, - tls: requestAuditLog.tls, - resourceName: resources.name, - resourceNiceId: resources.niceId + tls: requestAuditLog.tls }) .from(requestAuditLog) - .leftJoin( - resources, - eq(requestAuditLog.resourceId, resources.resourceId) - ) // TODO: Is this efficient? 
.where(getWhere(data)) .orderBy(desc(requestAuditLog.timestamp)); } +async function enrichWithResourceDetails(logs: Awaited>) { + // If logs database is the same as main database, we can do a join + // Otherwise, we need to fetch resource details separately + const resourceIds = logs + .map(log => log.resourceId) + .filter((id): id is number => id !== null && id !== undefined); + + if (resourceIds.length === 0) { + return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null })); + } + + // Fetch resource details from main database + const resourceDetails = await primaryDb + .select({ + resourceId: resources.resourceId, + name: resources.name, + niceId: resources.niceId + }) + .from(resources) + .where(inArray(resources.resourceId, resourceIds)); + + // Create a map for quick lookup + const resourceMap = new Map( + resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }]) + ); + + // Enrich logs with resource details + return logs.map(log => ({ + ...log, + resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null, + resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? 
null : null + })); +} + export function countRequestQuery(data: Q) { - const countQuery = primaryDb + const countQuery = primaryLogsDb .select({ count: count() }) .from(requestAuditLog) .where(getWhere(data)); @@ -185,36 +213,31 @@ async function queryUniqueFilterAttributes( uniquePaths, uniqueResources ] = await Promise.all([ - primaryDb + primaryLogsDb .selectDistinct({ actor: requestAuditLog.actor }) .from(requestAuditLog) .where(baseConditions) .limit(DISTINCT_LIMIT + 1), - primaryDb + primaryLogsDb .selectDistinct({ locations: requestAuditLog.location }) .from(requestAuditLog) .where(baseConditions) .limit(DISTINCT_LIMIT + 1), - primaryDb + primaryLogsDb .selectDistinct({ hosts: requestAuditLog.host }) .from(requestAuditLog) .where(baseConditions) .limit(DISTINCT_LIMIT + 1), - primaryDb + primaryLogsDb .selectDistinct({ paths: requestAuditLog.path }) .from(requestAuditLog) .where(baseConditions) .limit(DISTINCT_LIMIT + 1), - primaryDb + primaryLogsDb .selectDistinct({ - id: requestAuditLog.resourceId, - name: resources.name + id: requestAuditLog.resourceId }) .from(requestAuditLog) - .leftJoin( - resources, - eq(requestAuditLog.resourceId, resources.resourceId) - ) .where(baseConditions) .limit(DISTINCT_LIMIT + 1) ]); @@ -231,13 +254,33 @@ async function queryUniqueFilterAttributes( // throw new Error("Too many distinct filter attributes to retrieve. 
Please refine your time range."); // } + // Fetch resource names from main database for the unique resource IDs + const resourceIds = uniqueResources + .map(row => row.id) + .filter((id): id is number => id !== null); + + let resourcesWithNames: Array<{ id: number; name: string | null }> = []; + + if (resourceIds.length > 0) { + const resourceDetails = await primaryDb + .select({ + resourceId: resources.resourceId, + name: resources.name + }) + .from(resources) + .where(inArray(resources.resourceId, resourceIds)); + + resourcesWithNames = resourceDetails.map(r => ({ + id: r.resourceId, + name: r.name + })); + } + return { actors: uniqueActors .map((row) => row.actor) .filter((actor): actor is string => actor !== null), - resources: uniqueResources.filter( - (row): row is { id: number; name: string | null } => row.id !== null - ), + resources: resourcesWithNames, locations: uniqueLocations .map((row) => row.locations) .filter((location): location is string => location !== null), @@ -280,7 +323,10 @@ export async function queryRequestAuditLogs( const baseQuery = queryRequest(data); - const log = await baseQuery.limit(data.limit).offset(data.offset); + const logsRaw = await baseQuery.limit(data.limit).offset(data.offset); + + // Enrich with resource details (handles cross-database scenario) + const log = await enrichWithResourceDetails(logsRaw); const totalCountResult = await countRequestQuery(data); const totalCount = totalCountResult[0].count; diff --git a/server/routers/badger/logRequestAudit.ts b/server/routers/badger/logRequestAudit.ts index 5975d8f3..287cb030 100644 --- a/server/routers/badger/logRequestAudit.ts +++ b/server/routers/badger/logRequestAudit.ts @@ -1,4 +1,4 @@ -import { db, orgs, requestAuditLog } from "@server/db"; +import { logsDb, primaryLogsDb, db, orgs, requestAuditLog } from "@server/db"; import logger from "@server/logger"; import { and, eq, lt, sql } from "drizzle-orm"; import cache from "@server/lib/cache"; @@ -69,7 +69,7 @@ async function 
flushAuditLogs() { try { // Use a transaction to ensure all inserts succeed or fail together // This prevents index corruption from partial writes - await db.transaction(async (tx) => { + await logsDb.transaction(async (tx) => { // Batch insert logs in groups of 25 to avoid overwhelming the database const BATCH_DB_SIZE = 25; for (let i = 0; i < logsToWrite.length; i += BATCH_DB_SIZE) { @@ -130,7 +130,7 @@ export async function shutdownAuditLogger() { async function getRetentionDays(orgId: string): Promise { // check cache first - const cached = cache.get(`org_${orgId}_retentionDays`); + const cached = await cache.get(`org_${orgId}_retentionDays`); if (cached !== undefined) { return cached; } @@ -149,7 +149,7 @@ async function getRetentionDays(orgId: string): Promise { } // store the result in cache - cache.set( + await cache.set( `org_${orgId}_retentionDays`, org.settingsLogRetentionDaysRequest, 300 @@ -162,7 +162,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) { const cutoffTimestamp = calculateCutoffTimestamp(retentionDays); try { - await db + await logsDb .delete(requestAuditLog) .where( and( diff --git a/server/routers/badger/verifySession.ts b/server/routers/badger/verifySession.ts index b5c66c0e..6d537d52 100644 --- a/server/routers/badger/verifySession.ts +++ b/server/routers/badger/verifySession.ts @@ -37,7 +37,7 @@ import { enforceResourceSessionLength } from "#dynamic/lib/checkOrgAccessPolicy"; import { logRequestAudit } from "./logRequestAudit"; -import cache from "@server/lib/cache"; +import { localCache } from "@server/lib/cache"; import { APP_VERSION } from "@server/lib/consts"; import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; @@ -137,7 +137,7 @@ export async function verifyResourceSession( headerAuthExtendedCompatibility: ResourceHeaderAuthExtendedCompatibility | null; org: Org; } - | undefined = cache.get(resourceCacheKey); + | undefined = 
localCache.get(resourceCacheKey); if (!resourceData) { const result = await getResourceByDomain(cleanHost); @@ -161,7 +161,7 @@ export async function verifyResourceSession( } resourceData = result; - cache.set(resourceCacheKey, resourceData, 5); + localCache.set(resourceCacheKey, resourceData, 5); } const { @@ -405,7 +405,7 @@ export async function verifyResourceSession( // check for HTTP Basic Auth header const clientHeaderAuthKey = `headerAuth:${clientHeaderAuth}`; if (headerAuth && clientHeaderAuth) { - if (cache.get(clientHeaderAuthKey)) { + if (localCache.get(clientHeaderAuthKey)) { logger.debug( "Resource allowed because header auth is valid (cached)" ); @@ -428,7 +428,7 @@ export async function verifyResourceSession( headerAuth.headerAuthHash ) ) { - cache.set(clientHeaderAuthKey, clientHeaderAuth, 5); + localCache.set(clientHeaderAuthKey, clientHeaderAuth, 5); logger.debug("Resource allowed because header auth is valid"); logRequestAudit( @@ -520,7 +520,7 @@ export async function verifyResourceSession( if (resourceSessionToken) { const sessionCacheKey = `session:${resourceSessionToken}`; - let resourceSession: any = cache.get(sessionCacheKey); + let resourceSession: any = localCache.get(sessionCacheKey); if (!resourceSession) { const result = await validateResourceSessionToken( @@ -529,7 +529,7 @@ export async function verifyResourceSession( ); resourceSession = result?.resourceSession; - cache.set(sessionCacheKey, resourceSession, 5); + localCache.set(sessionCacheKey, resourceSession, 5); } if (resourceSession?.isRequestToken) { @@ -662,7 +662,7 @@ export async function verifyResourceSession( }:${resource.resourceId}`; let allowedUserData: BasicUserData | null | undefined = - cache.get(userAccessCacheKey); + localCache.get(userAccessCacheKey); if (allowedUserData === undefined) { allowedUserData = await isUserAllowedToAccessResource( @@ -671,7 +671,7 @@ export async function verifyResourceSession( resourceData.org ); - cache.set(userAccessCacheKey, 
allowedUserData, 5); + localCache.set(userAccessCacheKey, allowedUserData, 5); } if ( @@ -974,11 +974,11 @@ async function checkRules( ): Promise<"ACCEPT" | "DROP" | "PASS" | undefined> { const ruleCacheKey = `rules:${resourceId}`; - let rules: ResourceRule[] | undefined = cache.get(ruleCacheKey); + let rules: ResourceRule[] | undefined = localCache.get(ruleCacheKey); if (!rules) { rules = await getResourceRules(resourceId); - cache.set(ruleCacheKey, rules, 5); + localCache.set(ruleCacheKey, rules, 5); } if (rules.length === 0) { @@ -1208,13 +1208,13 @@ async function isIpInAsn( async function getAsnFromIp(ip: string): Promise { const asnCacheKey = `asn:${ip}`; - let cachedAsn: number | undefined = cache.get(asnCacheKey); + let cachedAsn: number | undefined = localCache.get(asnCacheKey); if (!cachedAsn) { cachedAsn = await getAsnForIp(ip); // do it locally // Cache for longer since IP ASN doesn't change frequently if (cachedAsn) { - cache.set(asnCacheKey, cachedAsn, 300); // 5 minutes + localCache.set(asnCacheKey, cachedAsn, 300); // 5 minutes } } @@ -1224,14 +1224,14 @@ async function getAsnFromIp(ip: string): Promise { async function getCountryCodeFromIp(ip: string): Promise { const geoIpCacheKey = `geoip:${ip}`; - let cachedCountryCode: string | undefined = cache.get(geoIpCacheKey); + let cachedCountryCode: string | undefined = localCache.get(geoIpCacheKey); if (!cachedCountryCode) { cachedCountryCode = await getCountryCodeForIp(ip); // do it locally // Only cache successful lookups to avoid filling cache with undefined values if (cachedCountryCode) { // Cache for longer since IP geolocation doesn't change frequently - cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes + localCache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes } } diff --git a/server/routers/integration.ts b/server/routers/integration.ts index 6c39fe98..7272d740 100644 --- a/server/routers/integration.ts +++ b/server/routers/integration.ts @@ -689,6 +689,13 @@ 
authenticated.get( user.getOrgUser ); +authenticated.get( + "/org/:orgId/user-by-username", + verifyApiKeyOrgAccess, + verifyApiKeyHasAction(ActionsEnum.getOrgUser), + user.getOrgUserByUsername +); + authenticated.post( "/user/:userId/2fa", verifyApiKeyIsRoot, diff --git a/server/routers/newt/handleSocketMessages.ts b/server/routers/newt/handleSocketMessages.ts index f26f69c9..2dd10008 100644 --- a/server/routers/newt/handleSocketMessages.ts +++ b/server/routers/newt/handleSocketMessages.ts @@ -24,8 +24,8 @@ export const handleDockerStatusMessage: MessageHandler = async (context) => { if (available) { logger.info(`Newt ${newt.newtId} has Docker socket access`); - cache.set(`${newt.newtId}:socketPath`, socketPath, 0); - cache.set(`${newt.newtId}:isAvailable`, available, 0); + await cache.set(`${newt.newtId}:socketPath`, socketPath, 0); + await cache.set(`${newt.newtId}:isAvailable`, available, 0); } else { logger.warn(`Newt ${newt.newtId} does not have Docker socket access`); } @@ -54,7 +54,7 @@ export const handleDockerContainersMessage: MessageHandler = async ( ); if (containers && containers.length > 0) { - cache.set(`${newt.newtId}:dockerContainers`, containers, 0); + await cache.set(`${newt.newtId}:dockerContainers`, containers, 0); } else { logger.warn(`Newt ${newt.newtId} does not have Docker containers`); } diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index c8ede518..2734a63b 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -1,4 +1,7 @@ -import { generateSessionToken } from "@server/auth/sessions/app"; +import { + generateSessionToken, + validateSessionToken +} from "@server/auth/sessions/app"; import { clients, db, @@ -26,8 +29,9 @@ import { APP_VERSION } from "@server/lib/consts"; export const olmGetTokenBodySchema = z.object({ olmId: z.string(), - secret: z.string(), - token: z.string().optional(), + secret: z.string().optional(), + userToken: z.string().optional(), + token: 
z.string().optional(), // this is the olm token orgId: z.string().optional() }); @@ -49,7 +53,7 @@ export async function getOlmToken( ); } - const { olmId, secret, token, orgId } = parsedBody.data; + const { olmId, secret, token, orgId, userToken } = parsedBody.data; try { if (token) { @@ -84,19 +88,45 @@ export async function getOlmToken( ); } - const validSecret = await verifyPassword( - secret, - existingOlm.secretHash - ); - - if (!validSecret) { - if (config.getRawConfig().app.log_failed_attempts) { - logger.info( - `Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.` + if (userToken) { + const { session: userSession, user } = + await validateSessionToken(userToken); + if (!userSession || !user) { + return next( + createHttpError(HttpCode.BAD_REQUEST, "Invalid user token") ); } + if (user.userId !== existingOlm.userId) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "User token does not match olm" + ) + ); + } + } else if (secret) { + // this is for backward compatibility, we want to move towards userToken but some old clients may still be using secret so we will support both for now + const validSecret = await verifyPassword( + secret, + existingOlm.secretHash + ); + + if (!validSecret) { + if (config.getRawConfig().app.log_failed_attempts) { + logger.info( + `Olm id or secret is incorrect. Olm: ID ${olmId}. 
IP: ${req.ip}.` + ); + } + return next( + createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect") + ); + } + } else { return next( - createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect") + createHttpError( + HttpCode.BAD_REQUEST, + "Either secret or userToken is required" + ) ); } diff --git a/server/routers/org/updateOrg.ts b/server/routers/org/updateOrg.ts index e94be3a9..5664ee9c 100644 --- a/server/routers/org/updateOrg.ts +++ b/server/routers/org/updateOrg.ts @@ -194,9 +194,9 @@ export async function updateOrg( } // invalidate the cache for all of the orgs retention days - cache.del(`org_${orgId}_retentionDays`); - cache.del(`org_${orgId}_actionDays`); - cache.del(`org_${orgId}_accessDays`); + await cache.del(`org_${orgId}_retentionDays`); + await cache.del(`org_${orgId}_actionDays`); + await cache.del(`org_${orgId}_accessDays`); return response(res, { data: updatedOrg[0], diff --git a/server/routers/site/listSites.ts b/server/routers/site/listSites.ts index e4881b1a..e5685a5a 100644 --- a/server/routers/site/listSites.ts +++ b/server/routers/site/listSites.ts @@ -23,7 +23,7 @@ import { fromError } from "zod-validation-error"; async function getLatestNewtVersion(): Promise { try { - const cachedVersion = cache.get("latestNewtVersion"); + const cachedVersion = await cache.get("latestNewtVersion"); if (cachedVersion) { return cachedVersion; } @@ -55,7 +55,7 @@ async function getLatestNewtVersion(): Promise { tags = tags.filter((version) => !version.name.includes("rc")); const latestVersion = tags[0].name; - cache.set("latestNewtVersion", latestVersion); + await cache.set("latestNewtVersion", latestVersion); return latestVersion; } catch (error: any) { diff --git a/server/routers/site/socketIntegration.ts b/server/routers/site/socketIntegration.ts index e0ad09d1..6a72a5d4 100644 --- a/server/routers/site/socketIntegration.ts +++ b/server/routers/site/socketIntegration.ts @@ -150,7 +150,7 @@ async function triggerFetch(siteId: number) { // clear the 
cache for this Newt ID so that the site has to keep asking for the containers // this is to ensure that the site always gets the latest data - cache.del(`${newt.newtId}:dockerContainers`); + await cache.del(`${newt.newtId}:dockerContainers`); return { siteId, newtId: newt.newtId }; } @@ -158,7 +158,7 @@ async function triggerFetch(siteId: number) { async function queryContainers(siteId: number) { const { newt } = await getSiteAndNewt(siteId); - const result = cache.get(`${newt.newtId}:dockerContainers`) as Container[]; + const result = await cache.get(`${newt.newtId}:dockerContainers`); if (!result) { throw createHttpError( HttpCode.TOO_EARLY, @@ -173,7 +173,7 @@ async function isDockerAvailable(siteId: number): Promise { const { newt } = await getSiteAndNewt(siteId); const key = `${newt.newtId}:isAvailable`; - const isAvailable = cache.get(key); + const isAvailable = await cache.get(key); return !!isAvailable; } @@ -186,9 +186,11 @@ async function getDockerStatus( const keys = ["isAvailable", "socketPath"]; const mappedKeys = keys.map((x) => `${newt.newtId}:${x}`); + const values = await cache.mget(mappedKeys); + const result = { - isAvailable: cache.get(mappedKeys[0]) as boolean, - socketPath: cache.get(mappedKeys[1]) as string | undefined + isAvailable: values[0] as boolean, + socketPath: values[1] as string | undefined }; return result; diff --git a/server/routers/user/getOrgUser.ts b/server/routers/user/getOrgUser.ts index f22a29d3..c0a990ee 100644 --- a/server/routers/user/getOrgUser.ts +++ b/server/routers/user/getOrgUser.ts @@ -11,7 +11,7 @@ import { fromError } from "zod-validation-error"; import { ActionsEnum, checkUserActionPermission } from "@server/auth/actions"; import { OpenAPITags, registry } from "@server/openApi"; -async function queryUser(orgId: string, userId: string) { +export async function queryUser(orgId: string, userId: string) { const [user] = await db .select({ orgId: userOrgs.orgId, diff --git 
a/server/routers/user/getOrgUserByUsername.ts b/server/routers/user/getOrgUserByUsername.ts new file mode 100644 index 00000000..b047fdc0 --- /dev/null +++ b/server/routers/user/getOrgUserByUsername.ts @@ -0,0 +1,136 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { db } from "@server/db"; +import { userOrgs, users } from "@server/db"; +import { and, eq } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import { OpenAPITags, registry } from "@server/openApi"; +import { queryUser, type GetOrgUserResponse } from "./getOrgUser"; + +const getOrgUserByUsernameParamsSchema = z.strictObject({ + orgId: z.string() +}); + +const getOrgUserByUsernameQuerySchema = z.strictObject({ + username: z.string().min(1, "username is required"), + idpId: z + .string() + .optional() + .transform((v) => + v === undefined || v === "" ? undefined : parseInt(v, 10) + ) + .refine( + (v) => + v === undefined || (Number.isInteger(v) && (v as number) > 0), + { message: "idpId must be a positive integer" } + ) +}); + +registry.registerPath({ + method: "get", + path: "/org/{orgId}/user-by-username", + description: + "Get a user in an organization by username. When idpId is not passed, only internal users are searched (username is globally unique for them). 
For external (OIDC) users, pass idpId to search by username within that identity provider.", + tags: [OpenAPITags.Org, OpenAPITags.User], + request: { + params: getOrgUserByUsernameParamsSchema, + query: getOrgUserByUsernameQuerySchema + }, + responses: {} +}); + +export async function getOrgUserByUsername( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = getOrgUserByUsernameParamsSchema.safeParse( + req.params + ); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const parsedQuery = getOrgUserByUsernameQuerySchema.safeParse( + req.query + ); + if (!parsedQuery.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedQuery.error).toString() + ) + ); + } + + const { orgId } = parsedParams.data; + const { username, idpId } = parsedQuery.data; + + const conditions = [ + eq(userOrgs.orgId, orgId), + eq(users.username, username) + ]; + if (idpId !== undefined) { + conditions.push(eq(users.idpId, idpId)); + } else { + conditions.push(eq(users.type, "internal")); + } + + const candidates = await db + .select({ userId: users.userId }) + .from(userOrgs) + .innerJoin(users, eq(userOrgs.userId, users.userId)) + .where(and(...conditions)); + + if (candidates.length === 0) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `User with username '${username}' not found in organization` + ) + ); + } + + if (candidates.length > 1) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Multiple users with this username (external users from different identity providers). Specify idpId (identity provider ID) to disambiguate. When not specified, this searches for internal users only." 
+ ) + ); + } + + const user = await queryUser(orgId, candidates[0].userId); + if (!user) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `User with username '${username}' not found in organization` + ) + ); + } + + return response(res, { + data: user, + success: true, + error: false, + message: "User retrieved successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/server/routers/user/index.ts b/server/routers/user/index.ts index 35c5c4a7..b6fb05d9 100644 --- a/server/routers/user/index.ts +++ b/server/routers/user/index.ts @@ -5,6 +5,7 @@ export * from "./addUserRole"; export * from "./inviteUser"; export * from "./acceptInvite"; export * from "./getOrgUser"; +export * from "./getOrgUserByUsername"; export * from "./adminListUsers"; export * from "./adminRemoveUser"; export * from "./adminGetUser"; diff --git a/server/routers/user/inviteUser.ts b/server/routers/user/inviteUser.ts index 693ef3b9..26fa8e55 100644 --- a/server/routers/user/inviteUser.ts +++ b/server/routers/user/inviteUser.ts @@ -191,7 +191,7 @@ export async function inviteUser( } if (existingInvite.length) { - const attempts = cache.get(email) || 0; + const attempts = (await cache.get(email)) || 0; if (attempts >= 3) { return next( createHttpError( @@ -201,7 +201,7 @@ export async function inviteUser( ); } - cache.set(email, attempts + 1); + await cache.set(email, attempts + 1); const inviteId = existingInvite[0].inviteId; // Retrieve the original inviteId const token = generateRandomString( diff --git a/src/app/[orgId]/settings/resources/proxy/[niceId]/authentication/page.tsx b/src/app/[orgId]/settings/resources/proxy/[niceId]/authentication/page.tsx index b00ce1ee..a533fb6c 100644 --- a/src/app/[orgId]/settings/resources/proxy/[niceId]/authentication/page.tsx +++ b/src/app/[orgId]/settings/resources/proxy/[niceId]/authentication/page.tsx @@ -187,7 
+187,11 @@ export default function ResourceAuthenticationPage() { number | null >(null); - const [ssoEnabled, setSsoEnabled] = useState(resource.sso); + const [ssoEnabled, setSsoEnabled] = useState(resource.sso ?? false); + + useEffect(() => { + setSsoEnabled(resource.sso ?? false); + }, [resource.sso]); const [selectedIdpId, setSelectedIdpId] = useState( resource.skipToIdpId || null @@ -472,7 +476,7 @@ export default function ResourceAuthenticationPage() { setSsoEnabled(val)} /> @@ -800,8 +804,13 @@ function OneTimePasswordFormSection({ }: OneTimePasswordFormSectionProps) { const { env } = useEnvContext(); const [whitelistEnabled, setWhitelistEnabled] = useState( - resource.emailWhitelistEnabled + resource.emailWhitelistEnabled ?? false ); + + useEffect(() => { + setWhitelistEnabled(resource.emailWhitelistEnabled); + }, [resource.emailWhitelistEnabled]); + const queryClient = useQueryClient(); const [loadingSaveWhitelist, startTransition] = useTransition(); @@ -894,7 +903,7 @@ function OneTimePasswordFormSection({ diff --git a/src/app/[orgId]/settings/resources/proxy/[niceId]/rules/page.tsx b/src/app/[orgId]/settings/resources/proxy/[niceId]/rules/page.tsx index f0ac0e1a..2b6a1687 100644 --- a/src/app/[orgId]/settings/resources/proxy/[niceId]/rules/page.tsx +++ b/src/app/[orgId]/settings/resources/proxy/[niceId]/rules/page.tsx @@ -113,7 +113,12 @@ export default function ResourceRules(props: { const [rulesToRemove, setRulesToRemove] = useState([]); const [loading, setLoading] = useState(false); const [pageLoading, setPageLoading] = useState(true); - const [rulesEnabled, setRulesEnabled] = useState(resource.applyRules); + const [rulesEnabled, setRulesEnabled] = useState(resource.applyRules ?? 
false); + + useEffect(() => { + setRulesEnabled(resource.applyRules); + }, [resource.applyRules]); + const [openCountrySelect, setOpenCountrySelect] = useState(false); const [countrySelectValue, setCountrySelectValue] = useState(""); const [openAddRuleCountrySelect, setOpenAddRuleCountrySelect] = @@ -836,7 +841,7 @@ export default function ResourceRules(props: { setRulesEnabled(val)} /> diff --git a/src/app/navigation.tsx b/src/app/navigation.tsx index 915e5f04..0066721d 100644 --- a/src/app/navigation.tsx +++ b/src/app/navigation.tsx @@ -107,7 +107,7 @@ export const orgNavSections = ( ] }, { - heading: "access", + heading: "accessControl", items: [ { title: "sidebarTeam", diff --git a/src/components/CopyToClipboard.tsx b/src/components/CopyToClipboard.tsx index b755f9a5..dca14728 100644 --- a/src/components/CopyToClipboard.tsx +++ b/src/components/CopyToClipboard.tsx @@ -31,6 +31,18 @@ const CopyToClipboard = ({ return (
+ {isLink ? ( )} -
); }; diff --git a/src/components/LayoutMobileMenu.tsx b/src/components/LayoutMobileMenu.tsx index f24c2f13..b661d780 100644 --- a/src/components/LayoutMobileMenu.tsx +++ b/src/components/LayoutMobileMenu.tsx @@ -5,13 +5,11 @@ import { SidebarNav } from "@app/components/SidebarNav"; import { OrgSelector } from "@app/components/OrgSelector"; import { cn } from "@app/lib/cn"; import { ListUserOrgsResponse } from "@server/routers/org"; -import SupporterStatus from "@app/components/SupporterStatus"; import { Button } from "@app/components/ui/button"; -import { ExternalLink, Menu, Server } from "lucide-react"; +import { ArrowRight, Menu, Server } from "lucide-react"; import Link from "next/link"; import { usePathname } from "next/navigation"; import { useUserContext } from "@app/hooks/useUserContext"; -import { useEnvContext } from "@app/hooks/useEnvContext"; import { useTranslations } from "next-intl"; import ProfileIcon from "@app/components/ProfileIcon"; import ThemeSwitcher from "@app/components/ThemeSwitcher"; @@ -44,7 +42,6 @@ export function LayoutMobileMenu({ const pathname = usePathname(); const isAdminPage = pathname?.startsWith("/admin"); const { user } = useUserContext(); - const { env } = useEnvContext(); const t = useTranslations(); return ( @@ -83,7 +80,7 @@ export function LayoutMobileMenu({
{!isAdminPage && user.serverAdmin && ( -
+
- + {t( "serverAdmin" )} +
)} @@ -115,22 +113,6 @@ export function LayoutMobileMenu({
-
- - {env?.app?.version && ( -
- - v{env.app.version} - - -
- )} -
diff --git a/src/components/LayoutSidebar.tsx b/src/components/LayoutSidebar.tsx index 940e91fe..e9e2d61e 100644 --- a/src/components/LayoutSidebar.tsx +++ b/src/components/LayoutSidebar.tsx @@ -146,6 +146,46 @@ export function LayoutSidebar({ />
+ {!isAdminPage && user.serverAdmin && ( +
+ + + + + {!isSidebarCollapsed && ( + <> + + {t("serverAdmin")} + + + + )} + +
+ )}
- {!isAdminPage && user.serverAdmin && ( -
- - - - - {!isSidebarCollapsed && ( - <> - - {t("serverAdmin")} - - - - )} - -
- )} - {isSidebarCollapsed && (
@@ -218,7 +224,7 @@ export function LayoutSidebar({
-
+
{canShowProductUpdates && (
diff --git a/src/lib/api/cookies.ts b/src/lib/api/cookies.ts index fe3c0090..c4c395c4 100644 --- a/src/lib/api/cookies.ts +++ b/src/lib/api/cookies.ts @@ -2,31 +2,20 @@ import { headers } from "next/headers"; export async function authCookieHeader() { const otherHeaders = await headers(); - const otherHeadersObject = Object.fromEntries(otherHeaders.entries()); + const otherHeadersObject = Object.fromEntries( + Array.from(otherHeaders.entries()).map(([k, v]) => [k.toLowerCase(), v]) + ); return { headers: { - cookie: - otherHeadersObject["cookie"] || otherHeadersObject["Cookie"], - host: otherHeadersObject["host"] || otherHeadersObject["Host"], - "user-agent": - otherHeadersObject["user-agent"] || - otherHeadersObject["User-Agent"], - "x-forwarded-for": - otherHeadersObject["x-forwarded-for"] || - otherHeadersObject["X-Forwarded-For"], - "x-forwarded-host": - otherHeadersObject["fx-forwarded-host"] || - otherHeadersObject["Fx-Forwarded-Host"], - "x-forwarded-port": - otherHeadersObject["x-forwarded-port"] || - otherHeadersObject["X-Forwarded-Port"], - "x-forwarded-proto": - otherHeadersObject["x-forwarded-proto"] || - otherHeadersObject["X-Forwarded-Proto"], - "x-real-ip": - otherHeadersObject["x-real-ip"] || - otherHeadersObject["X-Real-IP"] + cookie: otherHeadersObject["cookie"], + host: otherHeadersObject["host"], + "user-agent": otherHeadersObject["user-agent"], + "x-forwarded-for": otherHeadersObject["x-forwarded-for"], + "x-forwarded-host": otherHeadersObject["x-forwarded-host"], + "x-forwarded-port": otherHeadersObject["x-forwarded-port"], + "x-forwarded-proto": otherHeadersObject["x-forwarded-proto"], + "x-real-ip": otherHeadersObject["x-real-ip"] } }; }