Merge branch 'dev' into feat/device-approvals

This commit is contained in:
Fred KISSIE
2026-01-05 16:54:18 +01:00
165 changed files with 8514 additions and 2346 deletions

View File

@@ -99,12 +99,13 @@ async function query(query: Q) {
.where(and(baseConditions, not(isNull(requestAuditLog.location))))
.groupBy(requestAuditLog.location)
.orderBy(desc(totalQ))
.limit(DISTINCT_LIMIT+1);
.limit(DISTINCT_LIMIT + 1);
if (requestsPerCountry.length > DISTINCT_LIMIT) {
// throw an error
throw createHttpError(
HttpCode.BAD_REQUEST,
// todo: is this even possible?
`Too many distinct countries. Please narrow your query.`
);
}

View File

@@ -189,22 +189,22 @@ async function queryUniqueFilterAttributes(
.selectDistinct({ actor: requestAuditLog.actor })
.from(requestAuditLog)
.where(baseConditions)
.limit(DISTINCT_LIMIT+1),
.limit(DISTINCT_LIMIT + 1),
primaryDb
.selectDistinct({ locations: requestAuditLog.location })
.from(requestAuditLog)
.where(baseConditions)
.limit(DISTINCT_LIMIT+1),
.limit(DISTINCT_LIMIT + 1),
primaryDb
.selectDistinct({ hosts: requestAuditLog.host })
.from(requestAuditLog)
.where(baseConditions)
.limit(DISTINCT_LIMIT+1),
.limit(DISTINCT_LIMIT + 1),
primaryDb
.selectDistinct({ paths: requestAuditLog.path })
.from(requestAuditLog)
.where(baseConditions)
.limit(DISTINCT_LIMIT+1),
.limit(DISTINCT_LIMIT + 1),
primaryDb
.selectDistinct({
id: requestAuditLog.resourceId,
@@ -216,18 +216,20 @@ async function queryUniqueFilterAttributes(
eq(requestAuditLog.resourceId, resources.resourceId)
)
.where(baseConditions)
.limit(DISTINCT_LIMIT+1)
.limit(DISTINCT_LIMIT + 1)
]);
if (
uniqueActors.length > DISTINCT_LIMIT ||
uniqueLocations.length > DISTINCT_LIMIT ||
uniqueHosts.length > DISTINCT_LIMIT ||
uniquePaths.length > DISTINCT_LIMIT ||
uniqueResources.length > DISTINCT_LIMIT
) {
throw new Error("Too many distinct filter attributes to retrieve. Please refine your time range.");
}
// TODO: for stuff like the paths this is too restrictive so lets just show some of the paths and the user needs to
// refine the time range to see what they need to see
// if (
// uniqueActors.length > DISTINCT_LIMIT ||
// uniqueLocations.length > DISTINCT_LIMIT ||
// uniqueHosts.length > DISTINCT_LIMIT ||
// uniquePaths.length > DISTINCT_LIMIT ||
// uniqueResources.length > DISTINCT_LIMIT
// ) {
// throw new Error("Too many distinct filter attributes to retrieve. Please refine your time range.");
// }
return {
actors: uniqueActors
@@ -307,10 +309,12 @@ export async function queryRequestAuditLogs(
} catch (error) {
logger.error(error);
// if the message is "Too many distinct filter attributes to retrieve. Please refine your time range.", return a 400 and the message
if (error instanceof Error && error.message === "Too many distinct filter attributes to retrieve. Please refine your time range.") {
return next(
createHttpError(HttpCode.BAD_REQUEST, error.message)
);
if (
error instanceof Error &&
error.message ===
"Too many distinct filter attributes to retrieve. Please refine your time range."
) {
return next(createHttpError(HttpCode.BAD_REQUEST, error.message));
}
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")

View File

@@ -10,6 +10,7 @@ import { eq, and, gt } from "drizzle-orm";
import { createSession, generateSessionToken } from "@server/auth/sessions/app";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { stripPortFromHost } from "@server/lib/ip";
const paramsSchema = z.object({
code: z.string().min(1, "Code is required")
@@ -27,30 +28,6 @@ export type PollDeviceWebAuthResponse = {
token?: string;
};
// Helper function to extract IP from request (same as in startDeviceWebAuth)
function extractIpFromRequest(req: Request): string | undefined {
const ip = req.ip || req.socket.remoteAddress;
if (!ip) {
return undefined;
}
// Handle IPv6 format [::1] or IPv4 format
if (ip.startsWith("[") && ip.includes("]")) {
const ipv6Match = ip.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
// Handle IPv4 with port (split at last colon)
const lastColonIndex = ip.lastIndexOf(":");
if (lastColonIndex !== -1) {
return ip.substring(0, lastColonIndex);
}
return ip;
}
export async function pollDeviceWebAuth(
req: Request,
res: Response,
@@ -70,7 +47,7 @@ export async function pollDeviceWebAuth(
try {
const { code } = parsedParams.data;
const now = Date.now();
const requestIp = extractIpFromRequest(req);
const requestIp = req.ip ? stripPortFromHost(req.ip) : undefined;
// Hash the code before querying
const hashedCode = hashDeviceCode(code);

View File

@@ -12,6 +12,7 @@ import { TimeSpan } from "oslo";
import { maxmindLookup } from "@server/db/maxmind";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { stripPortFromHost } from "@server/lib/ip";
const bodySchema = z
.object({
@@ -39,30 +40,6 @@ function hashDeviceCode(code: string): string {
return encodeHexLowerCase(sha256(new TextEncoder().encode(code)));
}
// Helper function to extract IP from request
function extractIpFromRequest(req: Request): string | undefined {
const ip = req.ip;
if (!ip) {
return undefined;
}
// Handle IPv6 format [::1] or IPv4 format
if (ip.startsWith("[") && ip.includes("]")) {
const ipv6Match = ip.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
// Handle IPv4 with port (split at last colon)
const lastColonIndex = ip.lastIndexOf(":");
if (lastColonIndex !== -1) {
return ip.substring(0, lastColonIndex);
}
return ip;
}
// Helper function to get city from IP (if available)
async function getCityFromIp(ip: string): Promise<string | undefined> {
try {
@@ -112,7 +89,7 @@ export async function startDeviceWebAuth(
const hashedCode = hashDeviceCode(code);
// Extract IP from request
const ip = extractIpFromRequest(req);
const ip = req.ip ? stripPortFromHost(req.ip) : undefined;
// Get city (optional, may return undefined)
const city = ip ? await getCityFromIp(ip) : undefined;

View File

@@ -19,6 +19,7 @@ import {
import { SESSION_COOKIE_EXPIRES as RESOURCE_SESSION_COOKIE_EXPIRES } from "@server/auth/sessions/resource";
import config from "@server/lib/config";
import { response } from "@server/lib/response";
import { stripPortFromHost } from "@server/lib/ip";
const exchangeSessionBodySchema = z.object({
requestToken: z.string(),
@@ -62,7 +63,7 @@ export async function exchangeSession(
cleanHost = cleanHost.slice(0, -1 * matched.length);
}
const clientIp = requestIp?.split(":")[0];
const clientIp = requestIp ? stripPortFromHost(requestIp) : undefined;
const [resource] = await db
.select()

View File

@@ -3,6 +3,7 @@ import logger from "@server/logger";
import { and, eq, lt } from "drizzle-orm";
import cache from "@server/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
import { stripPortFromHost } from "@server/lib/ip";
/**
@@ -10,7 +11,7 @@ Reasons:
100 - Allowed by Rule
101 - Allowed No Auth
102 - Valid Access Token
103 - Valid header auth
103 - Valid Header Auth (HTTP Basic Auth)
104 - Valid Pincode
105 - Valid Password
106 - Valid email
@@ -48,27 +49,43 @@ const auditLogBuffer: Array<{
const BATCH_SIZE = 100; // Write to DB every 100 logs
const BATCH_INTERVAL_MS = 5000; // Or every 5 seconds, whichever comes first
const MAX_BUFFER_SIZE = 10000; // Prevent unbounded memory growth
let flushTimer: NodeJS.Timeout | null = null;
let isFlushInProgress = false;
/**
* Flush buffered logs to database
*/
async function flushAuditLogs() {
if (auditLogBuffer.length === 0) {
if (auditLogBuffer.length === 0 || isFlushInProgress) {
return;
}
isFlushInProgress = true;
// Take all current logs and clear buffer
const logsToWrite = auditLogBuffer.splice(0, auditLogBuffer.length);
try {
// Batch insert all logs at once
await db.insert(requestAuditLog).values(logsToWrite);
// Batch insert logs in groups of 25 to avoid overwhelming the database
const BATCH_DB_SIZE = 25;
for (let i = 0; i < logsToWrite.length; i += BATCH_DB_SIZE) {
const batch = logsToWrite.slice(i, i + BATCH_DB_SIZE);
await db.insert(requestAuditLog).values(batch);
}
logger.debug(`Flushed ${logsToWrite.length} audit logs to database`);
} catch (error) {
logger.error("Error flushing audit logs:", error);
// On error, we lose these logs - consider a fallback strategy if needed
// (e.g., write to file, or put back in buffer with retry limit)
} finally {
isFlushInProgress = false;
// If buffer filled up while we were flushing, flush again
if (auditLogBuffer.length >= BATCH_SIZE) {
flushAuditLogs().catch((err) =>
logger.error("Error in follow-up flush:", err)
);
}
}
}
@@ -94,6 +111,10 @@ export async function shutdownAuditLogger() {
clearTimeout(flushTimer);
flushTimer = null;
}
// Force flush even if one is in progress by waiting and retrying
while (isFlushInProgress) {
await new Promise((resolve) => setTimeout(resolve, 100));
}
await flushAuditLogs();
}
@@ -208,28 +229,17 @@ export async function logRequestAudit(
}
const clientIp = body.requestIp
? (() => {
if (
body.requestIp.startsWith("[") &&
body.requestIp.includes("]")
) {
// if brackets are found, extract the IPv6 address from between the brackets
const ipv6Match = body.requestIp.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
// ivp4
// split at last colon
const lastColonIndex = body.requestIp.lastIndexOf(":");
if (lastColonIndex !== -1) {
return body.requestIp.substring(0, lastColonIndex);
}
return body.requestIp;
})()
? stripPortFromHost(body.requestIp)
: undefined;
// Prevent unbounded buffer growth - drop oldest entries if buffer is too large
if (auditLogBuffer.length >= MAX_BUFFER_SIZE) {
const dropped = auditLogBuffer.splice(0, BATCH_SIZE);
logger.warn(
`Audit log buffer exceeded max size (${MAX_BUFFER_SIZE}), dropped ${dropped.length} oldest entries`
);
}
// Add to buffer instead of writing directly to DB
auditLogBuffer.push({
timestamp,

View File

@@ -14,13 +14,14 @@ import {
Org,
Resource,
ResourceHeaderAuth,
ResourceHeaderAuthExtendedCompatibility,
ResourcePassword,
ResourcePincode,
ResourceRule,
resourceSessions
} from "@server/db";
import config from "@server/lib/config";
import { isIpInCidr } from "@server/lib/ip";
import { isIpInCidr, stripPortFromHost } from "@server/lib/ip";
import { response } from "@server/lib/response";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
@@ -29,6 +30,7 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { getCountryCodeForIp } from "@server/lib/geoip";
import { getAsnForIp } from "@server/lib/asn";
import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { verifyPassword } from "@server/auth/password";
@@ -38,6 +40,8 @@ import {
} from "#dynamic/lib/checkOrgAccessPolicy";
import { logRequestAudit } from "./logRequestAudit";
import cache from "@server/lib/cache";
import semver from "semver";
import { APP_VERSION } from "@server/lib/consts";
const verifyResourceSessionSchema = z.object({
sessions: z.record(z.string(), z.string()).optional(),
@@ -49,7 +53,8 @@ const verifyResourceSessionSchema = z.object({
path: z.string(),
method: z.string(),
tls: z.boolean(),
requestIp: z.string().optional()
requestIp: z.string().optional(),
badgerVersion: z.string().optional()
});
export type VerifyResourceSessionSchema = z.infer<
@@ -65,8 +70,10 @@ type BasicUserData = {
export type VerifyUserResponse = {
valid: boolean;
headerAuthChallenged?: boolean;
redirectUrl?: string;
userData?: BasicUserData;
pangolinVersion?: string;
};
export async function verifyResourceSession(
@@ -95,31 +102,15 @@ export async function verifyResourceSession(
requestIp,
path,
headers,
query
query,
badgerVersion
} = parsedBody.data;
// Extract HTTP Basic Auth credentials if present
const clientHeaderAuth = extractBasicAuth(headers);
const clientIp = requestIp
? (() => {
logger.debug("Request IP:", { requestIp });
if (requestIp.startsWith("[") && requestIp.includes("]")) {
// if brackets are found, extract the IPv6 address from between the brackets
const ipv6Match = requestIp.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}
// ivp4
// split at last colon
const lastColonIndex = requestIp.lastIndexOf(":");
if (lastColonIndex !== -1) {
return requestIp.substring(0, lastColonIndex);
}
return requestIp;
})()
? stripPortFromHost(requestIp, badgerVersion)
: undefined;
logger.debug("Client IP:", { clientIp });
@@ -128,6 +119,8 @@ export async function verifyResourceSession(
? await getCountryCodeFromIp(clientIp)
: undefined;
const ipAsn = clientIp ? await getAsnFromIp(clientIp) : undefined;
let cleanHost = host;
// if the host ends with :port, strip it
if (cleanHost.match(/:[0-9]{1,5}$/)) {
@@ -142,6 +135,7 @@ export async function verifyResourceSession(
pincode: ResourcePincode | null;
password: ResourcePassword | null;
headerAuth: ResourceHeaderAuth | null;
headerAuthExtendedCompatibility: ResourceHeaderAuthExtendedCompatibility | null;
org: Org;
}
| undefined = cache.get(resourceCacheKey);
@@ -171,7 +165,13 @@ export async function verifyResourceSession(
cache.set(resourceCacheKey, resourceData, 5);
}
const { resource, pincode, password, headerAuth } = resourceData;
const {
resource,
pincode,
password,
headerAuth,
headerAuthExtendedCompatibility
} = resourceData;
if (!resource) {
logger.debug(`Resource not found ${cleanHost}`);
@@ -216,7 +216,8 @@ export async function verifyResourceSession(
resource.resourceId,
clientIp,
path,
ipCC
ipCC,
ipAsn
);
if (action == "ACCEPT") {
@@ -450,7 +451,8 @@ export async function verifyResourceSession(
!sso &&
!pincode &&
!password &&
!resource.emailWhitelistEnabled
!resource.emailWhitelistEnabled &&
!headerAuthExtendedCompatibility?.extendedCompatibilityIsActivated
) {
logRequestAudit(
{
@@ -471,7 +473,8 @@ export async function verifyResourceSession(
!sso &&
!pincode &&
!password &&
!resource.emailWhitelistEnabled
!resource.emailWhitelistEnabled &&
!headerAuthExtendedCompatibility?.extendedCompatibilityIsActivated
) {
logRequestAudit(
{
@@ -557,7 +560,7 @@ export async function verifyResourceSession(
}
if (resourceSession) {
// only run this check if not SSO sesion; SSO session length is checked later
// only run this check if not SSO session; SSO session length is checked later
const accessPolicy = await enforceResourceSessionLength(
resourceSession,
resourceData.org
@@ -701,6 +704,15 @@ export async function verifyResourceSession(
}
}
// If headerAuthExtendedCompatibility is activated but no clientHeaderAuth provided, force client to challenge
if (
headerAuthExtendedCompatibility &&
headerAuthExtendedCompatibility.extendedCompatibilityIsActivated &&
!clientHeaderAuth
) {
return headerAuthChallenged(res, redirectPath, resource.orgId);
}
logger.debug("No more auth to check, resource not allowed");
if (config.getRawConfig().app.log_failed_attempts) {
@@ -809,7 +821,7 @@ async function notAllowed(
}
const data = {
data: { valid: false, redirectUrl },
data: { valid: false, redirectUrl, pangolinVersion: APP_VERSION },
success: true,
error: false,
message: "Access denied",
@@ -823,8 +835,8 @@ function allowed(res: Response, userData?: BasicUserData) {
const data = {
data:
userData !== undefined && userData !== null
? { valid: true, ...userData }
: { valid: true },
? { valid: true, ...userData, pangolinVersion: APP_VERSION }
: { valid: true, pangolinVersion: APP_VERSION },
success: true,
error: false,
message: "Access allowed",
@@ -833,6 +845,51 @@ function allowed(res: Response, userData?: BasicUserData) {
return response<VerifyUserResponse>(res, data);
}
async function headerAuthChallenged(
res: Response,
redirectPath?: string,
orgId?: string
) {
let loginPage: LoginPage | null = null;
if (orgId) {
const { tier } = await getOrgTierData(orgId); // returns null in oss
if (tier === TierId.STANDARD) {
loginPage = await getOrgLoginPage(orgId);
}
}
let redirectUrl: string | undefined = undefined;
if (redirectPath) {
let endpoint: string;
if (loginPage && loginPage.domainId && loginPage.fullDomain) {
const secure = config
.getRawConfig()
.app.dashboard_url?.startsWith("https");
const method = secure ? "https" : "http";
endpoint = `${method}://${loginPage.fullDomain}`;
} else {
endpoint = config.getRawConfig().app.dashboard_url!;
}
redirectUrl = `${endpoint}${redirectPath}`;
}
const data = {
data: {
headerAuthChallenged: true,
valid: false,
redirectUrl,
pangolinVersion: APP_VERSION
},
success: true,
error: false,
message: "Access denied",
status: HttpCode.OK
};
logger.debug(JSON.stringify(data));
return response<VerifyUserResponse>(res, data);
}
async function isUserAllowedToAccessResource(
userSessionId: string,
resource: Resource,
@@ -910,7 +967,8 @@ async function checkRules(
resourceId: number,
clientIp: string | undefined,
path: string | undefined,
ipCC?: string
ipCC?: string,
ipAsn?: number
): Promise<"ACCEPT" | "DROP" | "PASS" | undefined> {
const ruleCacheKey = `rules:${resourceId}`;
@@ -954,6 +1012,12 @@ async function checkRules(
(await isIpInGeoIP(ipCC, rule.value))
) {
return rule.action as any;
} else if (
clientIp &&
rule.match == "ASN" &&
(await isIpInAsn(ipAsn, rule.value))
) {
return rule.action as any;
}
}
@@ -971,14 +1035,25 @@ export function isPathAllowed(pattern: string, path: string): boolean {
logger.debug(`Normalized pattern parts: [${patternParts.join(", ")}]`);
logger.debug(`Normalized path parts: [${pathParts.join(", ")}]`);
// Maximum recursion depth to prevent stack overflow and memory issues
const MAX_RECURSION_DEPTH = 100;
// Recursive function to try different wildcard matches
function matchSegments(patternIndex: number, pathIndex: number): boolean {
const indent = " ".repeat(pathIndex); // Indent based on recursion depth
function matchSegments(patternIndex: number, pathIndex: number, depth: number = 0): boolean {
// Check recursion depth limit
if (depth > MAX_RECURSION_DEPTH) {
logger.warn(
`Path matching exceeded maximum recursion depth (${MAX_RECURSION_DEPTH}) for pattern "${pattern}" and path "${path}"`
);
return false;
}
const indent = " ".repeat(depth); // Indent based on recursion depth
const currentPatternPart = patternParts[patternIndex];
const currentPathPart = pathParts[pathIndex];
logger.debug(
`${indent}Checking patternIndex=${patternIndex} (${currentPatternPart || "END"}) vs pathIndex=${pathIndex} (${currentPathPart || "END"})`
`${indent}Checking patternIndex=${patternIndex} (${currentPatternPart || "END"}) vs pathIndex=${pathIndex} (${currentPathPart || "END"}) [depth=${depth}]`
);
// If we've consumed all pattern parts, we should have consumed all path parts
@@ -1011,7 +1086,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
logger.debug(
`${indent}Trying to skip wildcard (consume 0 segments)`
);
if (matchSegments(patternIndex + 1, pathIndex)) {
if (matchSegments(patternIndex + 1, pathIndex, depth + 1)) {
logger.debug(
`${indent}Successfully matched by skipping wildcard`
);
@@ -1022,7 +1097,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
logger.debug(
`${indent}Trying to consume segment "${currentPathPart}" for wildcard`
);
if (matchSegments(patternIndex, pathIndex + 1)) {
if (matchSegments(patternIndex, pathIndex + 1, depth + 1)) {
logger.debug(
`${indent}Successfully matched by consuming segment for wildcard`
);
@@ -1050,7 +1125,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
logger.debug(
`${indent}Segment with wildcard matches: "${currentPatternPart}" matches "${currentPathPart}"`
);
return matchSegments(patternIndex + 1, pathIndex + 1);
return matchSegments(patternIndex + 1, pathIndex + 1, depth + 1);
}
logger.debug(
@@ -1071,10 +1146,10 @@ export function isPathAllowed(pattern: string, path: string): boolean {
`${indent}Segments match: "${currentPatternPart}" = "${currentPathPart}"`
);
// Move to next segments in both pattern and path
return matchSegments(patternIndex + 1, pathIndex + 1);
return matchSegments(patternIndex + 1, pathIndex + 1, depth + 1);
}
const result = matchSegments(0, 0);
const result = matchSegments(0, 0, 0);
logger.debug(`Final result: ${result}`);
return result;
}
@@ -1090,6 +1165,52 @@ async function isIpInGeoIP(
return ipCountryCode?.toUpperCase() === checkCountryCode.toUpperCase();
}
async function isIpInAsn(
ipAsn: number | undefined,
checkAsn: string
): Promise<boolean> {
// Handle "ALL" special case
if (checkAsn === "ALL" || checkAsn === "AS0") {
return true;
}
if (!ipAsn) {
return false;
}
// Normalize the check ASN - remove "AS" prefix if present and convert to number
const normalizedCheckAsn = checkAsn.toUpperCase().replace(/^AS/, "");
const checkAsnNumber = parseInt(normalizedCheckAsn, 10);
if (isNaN(checkAsnNumber)) {
logger.warn(`Invalid ASN format in rule: ${checkAsn}`);
return false;
}
const match = ipAsn === checkAsnNumber;
logger.debug(
`ASN check: IP ASN ${ipAsn} ${match ? "matches" : "does not match"} rule ASN ${checkAsnNumber}`
);
return match;
}
async function getAsnFromIp(ip: string): Promise<number | undefined> {
const asnCacheKey = `asn:${ip}`;
let cachedAsn: number | undefined = cache.get(asnCacheKey);
if (!cachedAsn) {
cachedAsn = await getAsnForIp(ip); // do it locally
// Cache for longer since IP ASN doesn't change frequently
if (cachedAsn) {
cache.set(asnCacheKey, cachedAsn, 300); // 5 minutes
}
}
return cachedAsn;
}
async function getCountryCodeFromIp(ip: string): Promise<string | undefined> {
const geoIpCacheKey = `geoip:${ip}`;
@@ -1097,8 +1218,11 @@ async function getCountryCodeFromIp(ip: string): Promise<string | undefined> {
if (!cachedCountryCode) {
cachedCountryCode = await getCountryCodeForIp(ip); // do it locally
// Cache for longer since IP geolocation doesn't change frequently
cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
// Only cache successful lookups to avoid filling cache with undefined values
if (cachedCountryCode) {
// Cache for longer since IP geolocation doesn't change frequently
cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
}
}
return cachedCountryCode;

View File

@@ -36,7 +36,7 @@ async function query(clientId?: number, niceId?: string, orgId?: string) {
.select()
.from(clients)
.where(and(eq(clients.niceId, niceId), eq(clients.orgId, orgId)))
.leftJoin(olms, eq(olms.clientId, olms.clientId))
.leftJoin(olms, eq(clients.clientId, olms.clientId))
.limit(1);
return res;
}

View File

@@ -56,12 +56,12 @@ async function getLatestOlmVersion(): Promise<string | null> {
return null;
}
const tags = await response.json();
let tags = await response.json();
if (!Array.isArray(tags) || tags.length === 0) {
logger.warn("No tags found for Olm repository");
return null;
}
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
olmVersionCache.set("latestOlmVersion", latestVersion);

View File

@@ -48,7 +48,6 @@ import createHttpError from "http-errors";
import { build } from "@server/build";
import { createStore } from "#dynamic/lib/rateLimitStore";
import { logActionAudit } from "#dynamic/middlewares";
import { log } from "console";
// Root routes
export const unauthenticated = Router();

View File

@@ -52,7 +52,7 @@ export async function getConfig(
}
// clean up the public key - keep only valid base64 characters (A-Z, a-z, 0-9, +, /, =)
const cleanedPublicKey = publicKey.replace(/[^A-Za-z0-9+/=]/g, '');
const cleanedPublicKey = publicKey.replace(/[^A-Za-z0-9+/=]/g, "");
const exitNode = await createExitNode(cleanedPublicKey, reachableAt);

View File

@@ -605,9 +605,18 @@ export async function validateOidcCallback(
res.appendHeader("Set-Cookie", cookie);
let finalRedirectUrl = postAuthRedirectUrl;
if (loginPageId) {
finalRedirectUrl = `/auth/org/?redirect=${encodeURIComponent(
postAuthRedirectUrl
)}`;
}
logger.debug("Final redirect URL", { finalRedirectUrl });
return response<ValidateOidcUrlCallbackResponse>(res, {
data: {
redirectUrl: postAuthRedirectUrl
redirectUrl: finalRedirectUrl
},
success: true,
error: false,

View File

@@ -858,6 +858,20 @@ authenticated.put(
blueprints.applyJSONBlueprint
);
authenticated.get(
"/org/:orgId/blueprint/:blueprintId",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.getBlueprint),
blueprints.getBlueprint
);
authenticated.get(
"/org/:orgId/blueprints",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.listBlueprints),
blueprints.listBlueprints
);
authenticated.get(
"/org/:orgId/logs/request",
verifyApiKeyOrgAccess,

View File

@@ -1,7 +1,7 @@
import { db } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Newt } from "@server/db";
import { eq } from "drizzle-orm";
import { clients } from "@server/db";
import { eq, sql } from "drizzle-orm";
import logger from "@server/logger";
interface PeerBandwidth {
@@ -10,13 +10,57 @@ interface PeerBandwidth {
bytesOut: number;
}
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
/**
* Check if an error is a deadlock error
*/
function isDeadlockError(error: any): boolean {
return (
error?.code === "40P01" ||
error?.cause?.code === "40P01" ||
(error?.message && error.message.includes("deadlock"))
);
}
/**
* Execute a function with retry logic for deadlock handling
*/
async function withDeadlockRetry<T>(
operation: () => Promise<T>,
context: string
): Promise<T> {
let attempt = 0;
while (true) {
try {
return await operation();
} catch (error: any) {
if (isDeadlockError(error) && attempt < MAX_RETRIES) {
attempt++;
const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS;
const jitter = Math.random() * baseDelay;
const delay = baseDelay + jitter;
logger.warn(
`Deadlock detected in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms`
);
await new Promise((resolve) => setTimeout(resolve, delay));
continue;
}
throw error;
}
}
}
export const handleReceiveBandwidthMessage: MessageHandler = async (
context
) => {
const { message, client, sendToClient } = context;
const { message } = context;
if (!message.data.bandwidthData) {
logger.warn("No bandwidth data provided");
return;
}
const bandwidthData: PeerBandwidth[] = message.data.bandwidthData;
@@ -25,30 +69,40 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
throw new Error("Invalid bandwidth data");
}
await db.transaction(async (trx) => {
for (const peer of bandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
// This is critical for preventing deadlocks when multiple instances update the same clients
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
a.publicKey.localeCompare(b.publicKey)
);
// Find the client by public key
const [client] = await trx
.select()
.from(clients)
.where(eq(clients.pubKey, publicKey))
.limit(1);
const currentTime = new Date().toISOString();
if (!client) {
continue;
}
// Update each client individually with retry logic
// This reduces transaction scope and allows retries per-client
for (const peer of sortedBandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
// Update the client's bandwidth usage
await trx
.update(clients)
.set({
megabytesOut: (client.megabytesIn || 0) + bytesIn,
megabytesIn: (client.megabytesOut || 0) + bytesOut,
lastBandwidthUpdate: new Date().toISOString()
})
.where(eq(clients.clientId, client.clientId));
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid SELECT then UPDATE pattern
// This eliminates the need to read the current value first
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data sent to client)
// and bytesOut from peer goes to megabytesIn (data received from client)
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `update client bandwidth ${publicKey}`);
} catch (error) {
logger.error(
`Failed to update bandwidth for client ${publicKey}:`,
error
);
// Continue with other clients even if one fails
}
});
}
};

View File

@@ -27,6 +27,7 @@ import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { doCidrsOverlap } from "@server/lib/ip";
const createOrgSchema = z.strictObject({
orgId: z.string(),
@@ -36,6 +37,11 @@ const createOrgSchema = z.strictObject({
.union([z.cidrv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere
.refine((val) => isValidCIDR(val), {
message: "Invalid subnet CIDR"
}),
utilitySubnet: z
.union([z.cidrv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere
.refine((val) => isValidCIDR(val), {
message: "Invalid utility subnet CIDR"
})
});
@@ -84,7 +90,7 @@ export async function createOrg(
);
}
const { orgId, name, subnet } = parsedBody.data;
const { orgId, name, subnet, utilitySubnet } = parsedBody.data;
// TODO: for now we are making all of the orgs the same subnet
// make sure the subnet is unique
@@ -119,6 +125,15 @@ export async function createOrg(
);
}
if (doCidrsOverlap(subnet, utilitySubnet)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Subnet ${subnet} overlaps with utility subnet ${utilitySubnet}`
)
);
}
let error = "";
let org: Org | null = null;
@@ -128,9 +143,6 @@ export async function createOrg(
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet =
config.getRawConfig().orgs.utility_subnet_group;
const newOrg = await trx
.insert(orgs)
.values({

View File

@@ -8,6 +8,7 @@ import config from "@server/lib/config";
export type PickOrgDefaultsResponse = {
subnet: string;
utilitySubnet: string;
};
export async function pickOrgDefaults(
@@ -20,10 +21,12 @@ export async function pickOrgDefaults(
// const subnet = await getNextAvailableOrgSubnet();
// Just hard code the subnet for now for everyone
const subnet = config.getRawConfig().orgs.subnet_group;
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
return response<PickOrgDefaultsResponse>(res, {
data: {
subnet: subnet
subnet: subnet,
utilitySubnet: utilitySubnet
},
success: true,
error: false,

View File

@@ -10,10 +10,10 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build";
import license from "#dynamic/license/license";
import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { cache } from "@server/lib/cache";
import { isLicensedOrSubscribed } from "@server/lib/isLicencedOrSubscribed";
const updateOrgParamsSchema = z.strictObject({
orgId: z.string()
@@ -155,22 +155,3 @@ export async function updateOrg(
);
}
}
async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
if (build === "enterprise") {
const isUnlocked = await license.isUnlocked();
if (!isUnlocked) {
return false;
}
}
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return false;
}
}
return true;
}

View File

@@ -17,7 +17,7 @@ import { OpenAPITags, registry } from "@server/openApi";
const createResourceRuleSchema = z.strictObject({
action: z.enum(["ACCEPT", "DROP", "PASS"]),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY"]),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY", "ASN"]),
value: z.string().min(1),
priority: z.int(),
enabled: z.boolean().optional()

View File

@@ -3,6 +3,7 @@ import { z } from "zod";
import {
db,
resourceHeaderAuth,
resourceHeaderAuthExtendedCompatibility,
resourcePassword,
resourcePincode,
resources
@@ -27,6 +28,7 @@ export type GetResourceAuthInfoResponse = {
password: boolean;
pincode: boolean;
headerAuth: boolean;
headerAuthExtendedCompatibility: boolean;
sso: boolean;
blockAccess: boolean;
url: string;
@@ -76,6 +78,13 @@ export async function getResourceAuthInfo(
resources.resourceId
)
)
.leftJoin(
resourceHeaderAuthExtendedCompatibility,
eq(
resourceHeaderAuthExtendedCompatibility.resourceId,
resources.resourceId
)
)
.where(eq(resources.resourceId, Number(resourceGuid)))
.limit(1)
: await db
@@ -89,6 +98,7 @@ export async function getResourceAuthInfo(
resourcePassword,
eq(resourcePassword.resourceId, resources.resourceId)
)
.leftJoin(
resourceHeaderAuth,
eq(
@@ -96,6 +106,13 @@ export async function getResourceAuthInfo(
resources.resourceId
)
)
.leftJoin(
resourceHeaderAuthExtendedCompatibility,
eq(
resourceHeaderAuthExtendedCompatibility.resourceId,
resources.resourceId
)
)
.where(eq(resources.resourceGuid, resourceGuid))
.limit(1);
@@ -109,6 +126,8 @@ export async function getResourceAuthInfo(
const pincode = result?.resourcePincode;
const password = result?.resourcePassword;
const headerAuth = result?.resourceHeaderAuth;
const headerAuthExtendedCompatibility =
result?.resourceHeaderAuthExtendedCompatibility;
const url = `${resource.ssl ? "https" : "http"}://${resource.fullDomain}`;
@@ -121,6 +140,8 @@ export async function getResourceAuthInfo(
password: password !== null,
pincode: pincode !== null,
headerAuth: headerAuth !== null,
headerAuthExtendedCompatibility:
headerAuthExtendedCompatibility !== null,
sso: resource.sso,
blockAccess: resource.blockAccess,
url,

View File

@@ -30,3 +30,4 @@ export * from "./removeRoleFromResource";
export * from "./addUserToResource";
export * from "./removeUserFromResource";
export * from "./listAllResourceNames";
export * from "./removeEmailFromResourceWhitelist";

View File

@@ -1,6 +1,10 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resourceHeaderAuth } from "@server/db";
import {
db,
resourceHeaderAuth,
resourceHeaderAuthExtendedCompatibility
} from "@server/db";
import {
resources,
userResources,
@@ -109,7 +113,8 @@ function queryResources(accessibleResourceIds: number[], orgId: string) {
domainId: resources.domainId,
niceId: resources.niceId,
headerAuthId: resourceHeaderAuth.headerAuthId,
headerAuthExtendedCompatibilityId:
resourceHeaderAuthExtendedCompatibility.headerAuthExtendedCompatibilityId,
targetId: targets.targetId,
targetIp: targets.ip,
targetPort: targets.port,
@@ -131,6 +136,13 @@ function queryResources(accessibleResourceIds: number[], orgId: string) {
resourceHeaderAuth,
eq(resourceHeaderAuth.resourceId, resources.resourceId)
)
.leftJoin(
resourceHeaderAuthExtendedCompatibility,
eq(
resourceHeaderAuthExtendedCompatibility.resourceId,
resources.resourceId
)
)
.leftJoin(targets, eq(targets.resourceId, resources.resourceId))
.leftJoin(
targetHealthCheck,

View File

@@ -1,6 +1,10 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resourceHeaderAuth } from "@server/db";
import {
db,
resourceHeaderAuth,
resourceHeaderAuthExtendedCompatibility
} from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -16,7 +20,8 @@ const setResourceAuthMethodsParamsSchema = z.object({
const setResourceAuthMethodsBodySchema = z.strictObject({
user: z.string().min(4).max(100).nullable(),
password: z.string().min(4).max(100).nullable()
password: z.string().min(4).max(100).nullable(),
extendedCompatibility: z.boolean().nullable()
});
registry.registerPath({
@@ -67,21 +72,38 @@ export async function setResourceHeaderAuth(
}
const { resourceId } = parsedParams.data;
const { user, password } = parsedBody.data;
const { user, password, extendedCompatibility } = parsedBody.data;
await db.transaction(async (trx) => {
await trx
.delete(resourceHeaderAuth)
.where(eq(resourceHeaderAuth.resourceId, resourceId));
await trx
.delete(resourceHeaderAuthExtendedCompatibility)
.where(
eq(
resourceHeaderAuthExtendedCompatibility.resourceId,
resourceId
)
);
if (user && password) {
if (user && password && extendedCompatibility !== null) {
const headerAuthHash = await hashPassword(
Buffer.from(`${user}:${password}`).toString("base64")
);
await trx
.insert(resourceHeaderAuth)
.values({ resourceId, headerAuthHash });
await Promise.all([
trx
.insert(resourceHeaderAuth)
.values({ resourceId, headerAuthHash }),
trx
.insert(resourceHeaderAuthExtendedCompatibility)
.values({
resourceId,
extendedCompatibilityIsActivated:
extendedCompatibility
})
]);
}
});

View File

@@ -0,0 +1,10 @@
export type GetMaintenanceInfoResponse = {
resourceId: number;
name: string;
fullDomain: string | null;
maintenanceModeEnabled: boolean;
maintenanceModeType: "forced" | "automatic" | null;
maintenanceTitle: string | null;
maintenanceMessage: string | null;
maintenanceEstimatedTime: string | null;
};

View File

@@ -22,8 +22,8 @@ import { registry } from "@server/openApi";
import { OpenAPITags } from "@server/openApi";
import { createCertificate } from "#dynamic/routers/certificates/createCertificate";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { validateHeaders } from "@server/lib/validators";
import { build } from "@server/build";
import { isLicensedOrSubscribed } from "@server/lib/isLicencedOrSubscribed";
const updateResourceParamsSchema = z.strictObject({
resourceId: z.string().transform(Number).pipe(z.int().positive())
@@ -48,7 +48,13 @@ const updateHttpResourceBodySchema = z
headers: z
.array(z.strictObject({ name: z.string(), value: z.string() }))
.nullable()
.optional()
.optional(),
// Maintenance mode fields
maintenanceModeEnabled: z.boolean().optional(),
maintenanceModeType: z.enum(["forced", "automatic"]).optional(),
maintenanceTitle: z.string().max(255).nullable().optional(),
maintenanceMessage: z.string().max(2000).nullable().optional(),
maintenanceEstimatedTime: z.string().max(100).nullable().optional()
})
.refine((data) => Object.keys(data).length > 0, {
error: "At least one field must be provided for update"
@@ -335,6 +341,19 @@ async function updateHttpResource(
headers = null;
}
const isLicensed = await isLicensedOrSubscribed(resource.orgId);
if (build == "enterprise" && !isLicensed) {
logger.warn(
"Server is not licensed! Clearing set maintenance screen values"
);
// null the maintenance mode fields if not licensed
updateData.maintenanceModeEnabled = undefined;
updateData.maintenanceModeType = undefined;
updateData.maintenanceTitle = undefined;
updateData.maintenanceMessage = undefined;
updateData.maintenanceEstimatedTime = undefined;
}
const updatedResource = await db
.update(resources)
.set({ ...updateData, headers })

View File

@@ -25,7 +25,7 @@ const updateResourceRuleParamsSchema = z.strictObject({
const updateResourceRuleSchema = z
.strictObject({
action: z.enum(["ACCEPT", "DROP", "PASS"]).optional(),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY"]).optional(),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY", "ASN"]).optional(),
value: z.string().min(1).optional(),
priority: z.int(),
enabled: z.boolean().optional()

View File

@@ -39,12 +39,12 @@ async function getLatestNewtVersion(): Promise<string | null> {
return null;
}
const tags = await response.json();
let tags = await response.json();
if (!Array.isArray(tags) || tags.length === 0) {
logger.warn("No tags found for Newt repository");
return null;
}
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
cache.set("latestNewtVersion", latestVersion);

View File

@@ -11,7 +11,11 @@ import {
userSiteResources
} from "@server/db";
import { getUniqueSiteResourceName } from "@server/db/names";
import { getNextAvailableAliasAddress, isIpInCidr, portRangeStringSchema } from "@server/lib/ip";
import {
getNextAvailableAliasAddress,
isIpInCidr,
portRangeStringSchema
} from "@server/lib/ip";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import response from "@server/lib/response";
import logger from "@server/logger";
@@ -69,7 +73,10 @@ const createSiteResourceSchema = z
const domainRegex =
/^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$/;
const isValidDomain = domainRegex.test(data.destination);
const isValidAlias = data.alias !== undefined && data.alias !== null && data.alias.trim() !== "";
const isValidAlias =
data.alias !== undefined &&
data.alias !== null &&
data.alias.trim() !== "";
return isValidDomain && isValidAlias; // require the alias to be set in the case of domain
}
@@ -182,7 +189,9 @@ export async function createSiteResource(
.limit(1);
if (!org) {
return next(createHttpError(HttpCode.NOT_FOUND, "Organization not found"));
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
if (!org.subnet || !org.utilitySubnet) {
@@ -195,10 +204,13 @@ export async function createSiteResource(
}
// Only check if destination is an IP address
const isIp = z.union([z.ipv4(), z.ipv6()]).safeParse(destination).success;
const isIp = z
.union([z.ipv4(), z.ipv6()])
.safeParse(destination).success;
if (
isIp &&
(isIpInCidr(destination, org.subnet) || isIpInCidr(destination, org.utilitySubnet))
(isIpInCidr(destination, org.subnet) ||
isIpInCidr(destination, org.utilitySubnet))
) {
return next(
createHttpError(

View File

@@ -88,9 +88,7 @@ export async function deleteSiteResource(
);
});
logger.info(
`Deleted site resource ${siteResourceId}`
);
logger.info(`Deleted site resource ${siteResourceId}`);
return response(res, {
data: { message: "Site resource deleted successfully" },

View File

@@ -204,7 +204,9 @@ export async function updateSiteResource(
.limit(1);
if (!org) {
return next(createHttpError(HttpCode.NOT_FOUND, "Organization not found"));
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
if (!org.subnet || !org.utilitySubnet) {
@@ -217,10 +219,13 @@ export async function updateSiteResource(
}
// Only check if destination is an IP address
const isIp = z.union([z.ipv4(), z.ipv6()]).safeParse(destination).success;
const isIp = z
.union([z.ipv4(), z.ipv6()])
.safeParse(destination).success;
if (
isIp &&
(isIpInCidr(destination!, org.subnet) || isIpInCidr(destination!, org.utilitySubnet))
(isIpInCidr(destination!, org.subnet) ||
isIpInCidr(destination!, org.utilitySubnet))
) {
return next(
createHttpError(
@@ -295,7 +300,7 @@ export async function updateSiteResource(
const [insertedSiteResource] = await trx
.insert(siteResources)
.values({
...existingSiteResource,
...existingSiteResource
})
.returning();
@@ -517,9 +522,14 @@ export async function handleMessagingForUpdatedSiteResource(
site: { siteId: number; orgId: string },
trx: Transaction
) {
logger.debug("handleMessagingForUpdatedSiteResource: existingSiteResource is: ", existingSiteResource);
logger.debug("handleMessagingForUpdatedSiteResource: updatedSiteResource is: ", updatedSiteResource);
logger.debug(
"handleMessagingForUpdatedSiteResource: existingSiteResource is: ",
existingSiteResource
);
logger.debug(
"handleMessagingForUpdatedSiteResource: updatedSiteResource is: ",
updatedSiteResource
);
const { mergedAllClients } =
await rebuildClientAssociationsFromSiteResource(

View File

@@ -213,9 +213,11 @@ export async function updateTarget(
// When health check is disabled, reset hcHealth to "unknown"
// to prevent previously unhealthy targets from being excluded
// Also when the site is not a newt, set hcHealth to "unknown"
const hcHealthValue =
parsedBody.data.hcEnabled === false ||
parsedBody.data.hcEnabled === null
parsedBody.data.hcEnabled === null ||
site.type !== "newt"
? "unknown"
: undefined;