Compare commits

..

93 Commits

Author SHA1 Message Date
Owen
2b0d6de986 Handle feature lifecycle for multiple orgs 2026-02-17 21:00:48 -08:00
Owen
057f82a561 Fix some cosmetics 2026-02-17 20:46:02 -08:00
Owen
719d2a5ffe Count everything when deleting the org 2026-02-17 20:39:47 -08:00
miloschwartz
d4bff9d5cb clean orgId and fix primary badge 2026-02-17 20:35:36 -08:00
Owen
19fcc1f93b Set org limit 2026-02-17 20:18:50 -08:00
miloschwartz
d45ea127c2 use billing org id in get subscription status 2026-02-17 20:07:29 -08:00
Owen
f591cf8601 Look to the right org to test is subscribed 2026-02-17 20:06:58 -08:00
Owen
6661a76aa8 Update member resources page and testing new org counts 2026-02-17 20:01:43 -08:00
miloschwartz
a2ed22bfcc use add/remove helper functions in auto (de)provision 2026-02-17 17:50:41 -08:00
Owen
e370f8891a Also update in the assign 2026-02-17 17:34:57 -08:00
miloschwartz
8a83e32c42 add send email verification opt out 2026-02-17 17:33:35 -08:00
Owen
831eb6325c Centralize user functions 2026-02-17 17:31:41 -08:00
Owen
4d6240c987 Handle new usage tracking with multi org 2026-02-17 17:10:05 -08:00
miloschwartz
79cf7c84dc support delete org and preserve path on switch 2026-02-17 16:45:15 -08:00
Owen
b71f582329 Use the billing org id when updating and checking usage 2026-02-17 15:09:42 -08:00
miloschwartz
b8c3cc751a support creating multiple orgs in saas 2026-02-17 14:37:46 -08:00
Owen
d00262dc31 Send the right port and cert 2026-02-17 11:43:38 -08:00
Owen
3debc6c8d3 Add round trip tracking for any message 2026-02-16 20:29:55 -08:00
Owen
5092eb58fb Ssh host should be the destination 2026-02-16 15:31:09 -08:00
Owen
f0b9240575 Accept resource as either niceId or alias 2026-02-16 15:29:23 -08:00
Owen
9cf59c409e Initial sign endpoint working 2026-02-16 15:19:29 -08:00
Owen
bfd5aa30a7 Merge branch 'dev' of github.com:fosrl/pangolin into dev 2026-02-15 11:09:11 -08:00
Owen
9737170665 Merge branch 'Lokowitz-update-packages' into dev 2026-02-15 11:08:12 -08:00
Owen
922a040466 Merge branch 'update-packages' of github.com:Lokowitz/pangolin into Lokowitz-update-packages 2026-02-15 11:08:02 -08:00
miloschwartz
33f0782f3a support delete account 2026-02-14 18:01:37 -08:00
Milo Schwartz
e6a5cef945 Merge pull request #2371 from Fredkiss3/refactor/paginated-tables
feat: server side filtered, ordered & paginated tables
2026-02-14 11:43:01 -08:00
miloschwartz
4c8edb80b3 dont show table footer in client-side data-table 2026-02-14 11:40:59 -08:00
miloschwartz
d4668fae99 add openapi types 2026-02-14 11:25:00 -08:00
Fred KISSIE
ddfe55e3ae ♻️ add niceId to query filtering on most tables 2026-02-14 04:19:30 +01:00
Fred KISSIE
761a5f1d4c ♻️ use like & LOWER(column) for searching with query 2026-02-14 04:11:27 +01:00
Fred KISSIE
1fbcad8787 ♻️ refactor 2026-02-14 04:06:11 +01:00
miloschwartz
aba586e605 change translation 2026-02-13 17:35:54 -08:00
Milo Schwartz
27b21b5ad4 Merge pull request #2359 from Fredkiss3/feat/logo-path-in-enterprise
feat: Support file path in branding logo URL for enterprise
2026-02-13 17:16:33 -08:00
Milo Schwartz
b6e54dab17 Merge branch 'dev' into feat/logo-path-in-enterprise 2026-02-13 17:16:25 -08:00
miloschwartz
1f8e89772d disable global idp routes if idp mode is org 2026-02-13 15:46:13 -08:00
Lokowitz
5f3657fd56 update packages 2026-02-13 06:15:26 +00:00
Lokowitz
494162400e Merge remote-tracking branch 'origin/dev' into update-packages 2026-02-13 06:12:24 +00:00
Fred KISSIE
ab65bb6a8a Merge branch 'dev' into refactor/paginated-tables 2026-02-13 06:03:09 +01:00
Lokowitz
4e1e0cade1 upgrade package 2026-02-12 15:51:19 +00:00
Lokowitz
fda5904dac Merge remote-tracking branch 'origin/dev' into update-packages 2026-02-12 15:47:29 +00:00
Fred KISSIE
6d1665004b 🏷️ fix type errors 2026-02-11 04:34:53 +01:00
Fred KISSIE
59b8119fbd Merge branch 'dev' into refactor/paginated-tables 2026-02-11 04:12:40 +01:00
Fred KISSIE
45cd4df6e5 ♻️ agent 2026-02-11 00:37:42 +01:00
Lokowitz
d5b6de70da update package and update dockerfile 2026-02-10 11:49:53 +00:00
Fred KISSIE
d6ade102dc filter & paginate on machine clients table 2026-02-10 05:14:37 +01:00
Fred KISSIE
c94d246c24 ♻️ list machine query 2026-02-10 04:00:45 +01:00
Fred KISSIE
5b779ba9fe ♻️ refactor 2026-02-10 03:21:12 +01:00
Fred KISSIE
3ba2cb19a9 approval feed 2026-02-10 03:20:49 +01:00
Fred KISSIE
da514ef314 ♻️ refactor 2026-02-10 00:45:34 +01:00
Fred KISSIE
7f73cde794 ♻️ refetch approval count every 30s 2026-02-10 00:45:20 +01:00
Fred KISSIE
b0af0d9cd5 ♻️ keep previous data 2026-02-10 00:31:21 +01:00
Lokowitz
8429197b07 fix .gitignore 2026-02-09 19:56:32 +00:00
Lokowitz
44f2081882 update packages 2026-02-09 17:08:59 +00:00
Lokowitz
63f7dd1d20 fix lint error 2026-02-07 08:48:58 +00:00
Lokowitz
57b8c69983 remove date-fns 2026-02-07 08:34:56 +00:00
Lokowitz
aad060810a update package and move eslint to dev 2026-02-07 08:26:05 +00:00
Lokowitz
9222b00a6f Merge remote-tracking branch 'origin/dev' into update-packages 2026-02-07 08:14:16 +00:00
Fred KISSIE
ff61b22e7e ♻️do not set default values 2026-02-07 05:37:52 +01:00
Fred KISSIE
577cb91343 whole table filter 2026-02-07 05:37:01 +01:00
Fred KISSIE
1889386f64 🚧 wip: table filters 2026-02-07 04:51:37 +01:00
Fred KISSIE
5d7f082ebf sort user device table & refactor sort into common functino 2026-02-07 04:41:42 +01:00
Fred KISSIE
db6327c4ff 🔇 remove console.logs in API 2026-02-07 02:52:23 +01:00
Fred KISSIE
fd7f6b2b99 filter user devices API finished 2026-02-07 02:51:32 +01:00
Fred KISSIE
49435398a8 🔥 cleanup imports 2026-02-07 02:50:59 +01:00
Fred KISSIE
9f2fd34e99 🚧 wip: user devices endpoint 2026-02-06 05:37:44 +01:00
Fred KISSIE
67b63d3084 ♻️ make code cleanrer 2026-02-06 04:52:21 +01:00
Fred KISSIE
4a31a7b84b 🚨 fix lint error 2026-02-06 03:55:11 +01:00
Fred KISSIE
538b601b1e Merge branch 'dev' into refactor/paginated-tables 2026-02-06 03:54:50 +01:00
Fred KISSIE
588f064c25 🚸 make resource enabled switch optimistic 2026-02-06 03:53:14 +01:00
Fred KISSIE
d521e79662 🏷️ fix types 2026-02-06 03:21:00 +01:00
Fred KISSIE
ccddb9244d 🏷️ add types on mode in sqlite 2026-02-06 03:14:03 +01:00
Fred KISSIE
0547396213 ♻️ do not sort client resources 2026-02-06 02:44:23 +01:00
Fred KISSIE
6c85171091 serverside filter+paginate client resources table 2026-02-06 02:42:15 +01:00
Lokowitz
0f4d1d2a74 add preview server 2026-02-05 19:46:57 +00:00
Lokowitz
941d5c08e3 upgrade packages 2026-02-05 19:26:36 +00:00
Lokowitz
db9f74158b Merge remote-tracking branch 'origin/dev' into update-packages 2026-02-05 19:24:12 +00:00
Fred KISSIE
609ffccd67 🏷️ fix typescript error 2026-02-05 05:35:59 +01:00
Fred KISSIE
748af1d8cb ♻️ cleanup code for searching & filtering 2026-02-05 05:21:25 +01:00
Fred KISSIE
d309ec249e filter resources by status 2026-02-05 03:15:18 +01:00
Fred KISSIE
67949b4968 🚧 wip: healthStatus 2026-02-04 04:10:08 +01:00
Fred KISSIE
1fc40b3017 filter by auth state 2026-02-04 03:42:05 +01:00
Fred KISSIE
bb1a375484 paginate, search & filter resources by enabled 2026-02-04 02:20:28 +01:00
Lokowitz
13c011895d update packages and node 2026-02-02 19:17:40 +00:00
Lokowitz
bd8d0e3392 update packages 2026-02-02 18:48:35 +00:00
Fred KISSIE
cda6b67bef search, filter & paginate sites table 2026-01-31 03:02:39 +01:00
Fred KISSIE
066305b095 toggle column sorting & pagination 2026-01-31 00:45:14 +01:00
Fred KISSIE
89695df012 🚧 wip: pagination and search work 2026-01-30 05:39:01 +01:00
Fred KISSIE
b04385a340 🚧 search on table 2026-01-29 05:48:41 +01:00
Fred KISSIE
d374ea6ea6 🚧wip 2026-01-29 05:07:41 +01:00
Fred KISSIE
01a2820390 🚧 POC: pagination in sites table 2026-01-29 05:07:27 +01:00
Fred KISSIE
c89c1a03da 🎨 use prettier for formatting typescript 2026-01-29 05:05:34 +01:00
Fred KISSIE
38ac4c5980 🚧 wip: paginated tables 2026-01-28 04:46:54 +01:00
Fred KISSIE
ed3ee64e4b support pathname in logo URL in branding page 2026-01-28 03:04:12 +01:00
115 changed files with 8553 additions and 4351 deletions

View File

@@ -32,4 +32,5 @@ migrations/
config/ config/
build.ts build.ts
tsconfig.json tsconfig.json
Dockerfile*
migrations/ migrations/

View File

@@ -525,10 +525,41 @@ jobs:
VERIFIED_INDEX_KEYLESS=false VERIFIED_INDEX_KEYLESS=false
fi fi
# Check if verification succeeded # If index verification fails, attempt to verify child platform manifests
if [ "${VERIFIED_INDEX}" != "true" ] && [ "${VERIFIED_INDEX_KEYLESS}" != "true" ]; then if [ "${VERIFIED_INDEX}" != "true" ] || [ "${VERIFIED_INDEX_KEYLESS}" != "true" ]; then
echo "⚠️ WARNING: Verification not available for ${BASE_IMAGE}:${IMAGE_TAG}" echo "Index verification not available; attempting child manifest verification for ${BASE_IMAGE}:${IMAGE_TAG}"
echo "This may be due to registry propagation delays. Continuing anyway." CHILD_VERIFIED=false
for ARCH in arm64 amd64; do
CHILD_TAG="${IMAGE_TAG}-${ARCH}"
echo "Resolving child digest for ${BASE_IMAGE}:${CHILD_TAG}"
CHILD_DIGEST="$(skopeo inspect --retry-times 3 docker://${BASE_IMAGE}:${CHILD_TAG} | jq -r '.Digest' || true)"
if [ -n "${CHILD_DIGEST}" ] && [ "${CHILD_DIGEST}" != "null" ]; then
CHILD_REF="${BASE_IMAGE}@${CHILD_DIGEST}"
echo "==> cosign verify (public key) child ${CHILD_REF}"
if retry_verify "cosign verify --key env://COSIGN_PUBLIC_KEY '${CHILD_REF}' -o text"; then
CHILD_VERIFIED=true
echo "Public key verification succeeded for child ${CHILD_REF}"
else
echo "Public key verification failed for child ${CHILD_REF}"
fi
echo "==> cosign verify (keyless policy) child ${CHILD_REF}"
if retry_verify "cosign verify --certificate-oidc-issuer '${issuer}' --certificate-identity-regexp '${id_regex}' '${CHILD_REF}' -o text"; then
CHILD_VERIFIED=true
echo "Keyless verification succeeded for child ${CHILD_REF}"
else
echo "Keyless verification failed for child ${CHILD_REF}"
fi
else
echo "No child digest found for ${BASE_IMAGE}:${CHILD_TAG}; skipping"
fi
done
if [ "${CHILD_VERIFIED}" != "true" ]; then
echo "Failed to verify index and no child manifests verified for ${BASE_IMAGE}:${IMAGE_TAG}"
exit 1
fi
fi fi
) || TAG_FAILED=true ) || TAG_FAILED=true

1
.gitignore vendored
View File

@@ -53,3 +53,4 @@ tsconfig.json
hydrateSaas.ts hydrateSaas.ts
CLAUDE.md CLAUDE.md
drizzle.config.ts drizzle.config.ts
server/setup/migrations.ts

View File

@@ -10,7 +10,7 @@
"editor.defaultFormatter": "esbenp.prettier-vscode" "editor.defaultFormatter": "esbenp.prettier-vscode"
}, },
"[typescript]": { "[typescript]": {
"editor.defaultFormatter": "vscode.typescript-language-features" "editor.defaultFormatter": "esbenp.prettier-vscode"
}, },
"[typescriptreact]": { "[typescriptreact]": {
"editor.defaultFormatter": "esbenp.prettier-vscode" "editor.defaultFormatter": "esbenp.prettier-vscode"

View File

@@ -1,33 +1,54 @@
FROM node:24-alpine AS builder FROM node:24-alpine AS base
WORKDIR /app WORKDIR /app
ARG BUILD=oss
ARG DATABASE=sqlite
RUN apk add --no-cache python3 make g++ RUN apk add --no-cache python3 make g++
# COPY package.json package-lock.json ./
COPY package*.json ./ COPY package*.json ./
FROM base AS builder-dev
RUN npm ci RUN npm ci
COPY . . COPY . .
ARG BUILD=oss
ARG DATABASE=sqlite
RUN if [ "$BUILD" = "oss" ]; then rm -rf server/private; fi && \ RUN if [ "$BUILD" = "oss" ]; then rm -rf server/private; fi && \
npm run set:$DATABASE && \ npm run set:$DATABASE && \
npm run set:$BUILD && \ npm run set:$BUILD && \
npm run db:generate && \ npm run db:generate && \
npm run build && \ npm run build && \
npm run build:cli npm run build:cli && \
test -f dist/server.mjs
# test to make sure the build output is there and error if not FROM base AS builder
RUN test -f dist/server.mjs
# Prune dev dependencies and clean up to prepare for copy to runner RUN npm ci --omit=dev
RUN npm prune --omit=dev && npm cache clean --force
FROM node:24-alpine AS runner FROM node:24-alpine AS runner
WORKDIR /app
RUN apk add --no-cache curl tzdata
COPY --from=builder /app/node_modules ./node_modules
COPY --from=builder /app/package.json ./package.json
COPY --from=builder-dev /app/.next/standalone ./
COPY --from=builder-dev /app/.next/static ./.next/static
COPY --from=builder-dev /app/dist ./dist
COPY --from=builder-dev /app/server/migrations ./dist/init
COPY ./cli/wrapper.sh /usr/local/bin/pangctl
RUN chmod +x /usr/local/bin/pangctl ./dist/cli.mjs
COPY server/db/names.json ./dist/names.json
COPY server/db/ios_models.json ./dist/ios_models.json
COPY server/db/mac_models.json ./dist/mac_models.json
COPY public ./public
# OCI Image Labels - Build Args for dynamic values # OCI Image Labels - Build Args for dynamic values
ARG VERSION="dev" ARG VERSION="dev"
ARG REVISION="" ARG REVISION=""
@@ -38,28 +59,6 @@ ARG LICENSE="AGPL-3.0"
ARG IMAGE_TITLE="Pangolin" ARG IMAGE_TITLE="Pangolin"
ARG IMAGE_DESCRIPTION="Identity-aware VPN and proxy for remote access to anything, anywhere" ARG IMAGE_DESCRIPTION="Identity-aware VPN and proxy for remote access to anything, anywhere"
WORKDIR /app
# Only curl and tzdata needed at runtime - no build tools!
RUN apk add --no-cache curl tzdata
# Copy pre-built node_modules from builder (already pruned to production only)
# This includes the compiled native modules like better-sqlite3
COPY --from=builder /app/node_modules ./node_modules
COPY --from=builder /app/.next/standalone ./
COPY --from=builder /app/.next/static ./.next/static
COPY --from=builder /app/dist ./dist
COPY --from=builder /app/server/migrations ./dist/init
COPY --from=builder /app/package.json ./package.json
COPY ./cli/wrapper.sh /usr/local/bin/pangctl
RUN chmod +x /usr/local/bin/pangctl ./dist/cli.mjs
COPY server/db/names.json ./dist/names.json
COPY server/db/ios_models.json ./dist/ios_models.json
COPY server/db/mac_models.json ./dist/mac_models.json
COPY public ./public
# OCI Image Labels # OCI Image Labels
# https://github.com/opencontainers/image-spec/blob/main/annotations.md # https://github.com/opencontainers/image-spec/blob/main/annotations.md
LABEL org.opencontainers.image.source="https://github.com/fosrl/pangolin" \ LABEL org.opencontainers.image.source="https://github.com/fosrl/pangolin" \

View File

@@ -1,7 +1,9 @@
FROM node:22-alpine FROM node:24-alpine
WORKDIR /app WORKDIR /app
RUN apk add --no-cache python3 make g++
COPY package*.json ./ COPY package*.json ./
# Install dependencies # Install dependencies

View File

@@ -281,7 +281,7 @@ esbuild
}) })
], ],
sourcemap: "inline", sourcemap: "inline",
target: "node22" target: "node24"
}) })
.then((result) => { .then((result) => {
// Check if there were any errors in the build result // Check if there were any errors in the build result

View File

@@ -473,6 +473,8 @@
"filterByApprovalState": "Filter By Approval State", "filterByApprovalState": "Filter By Approval State",
"approvalListEmpty": "No approvals", "approvalListEmpty": "No approvals",
"approvalState": "Approval State", "approvalState": "Approval State",
"approvalLoadMore": "Load more",
"loadingApprovals": "Loading Approvals",
"approve": "Approve", "approve": "Approve",
"approved": "Approved", "approved": "Approved",
"denied": "Denied", "denied": "Denied",
@@ -1029,6 +1031,7 @@
"pangolinSetup": "Setup - Pangolin", "pangolinSetup": "Setup - Pangolin",
"orgNameRequired": "Organization name is required", "orgNameRequired": "Organization name is required",
"orgIdRequired": "Organization ID is required", "orgIdRequired": "Organization ID is required",
"orgIdMaxLength": "Organization ID must be at most 32 characters",
"orgErrorCreate": "An error occurred while creating org", "orgErrorCreate": "An error occurred while creating org",
"pageNotFound": "Page Not Found", "pageNotFound": "Page Not Found",
"pageNotFoundDescription": "Oops! The page you're looking for doesn't exist.", "pageNotFoundDescription": "Oops! The page you're looking for doesn't exist.",
@@ -1181,7 +1184,8 @@
"actionViewLogs": "View Logs", "actionViewLogs": "View Logs",
"noneSelected": "None selected", "noneSelected": "None selected",
"orgNotFound2": "No organizations found.", "orgNotFound2": "No organizations found.",
"searchProgress": "Search...", "searchPlaceholder": "Search...",
"emptySearchOptions": "No options found",
"create": "Create", "create": "Create",
"orgs": "Organizations", "orgs": "Organizations",
"loginError": "An unexpected error occurred. Please try again.", "loginError": "An unexpected error occurred. Please try again.",
@@ -1263,6 +1267,7 @@
"sidebarLogAndAnalytics": "Log & Analytics", "sidebarLogAndAnalytics": "Log & Analytics",
"sidebarBluePrints": "Blueprints", "sidebarBluePrints": "Blueprints",
"sidebarOrganization": "Organization", "sidebarOrganization": "Organization",
"sidebarBillingAndLicenses": "Billing & Licenses",
"sidebarLogsAnalytics": "Analytics", "sidebarLogsAnalytics": "Analytics",
"blueprints": "Blueprints", "blueprints": "Blueprints",
"blueprintsDescription": "Apply declarative configurations and view previous runs", "blueprintsDescription": "Apply declarative configurations and view previous runs",
@@ -1424,6 +1429,7 @@
"billingSites": "Sites", "billingSites": "Sites",
"billingUsers": "Users", "billingUsers": "Users",
"billingDomains": "Domains", "billingDomains": "Domains",
"billingOrganizations": "Orgs",
"billingRemoteExitNodes": "Remote Nodes", "billingRemoteExitNodes": "Remote Nodes",
"billingNoLimitConfigured": "No limit configured", "billingNoLimitConfigured": "No limit configured",
"billingEstimatedPeriod": "Estimated Billing Period", "billingEstimatedPeriod": "Estimated Billing Period",
@@ -1466,6 +1472,7 @@
"failed": "Failed", "failed": "Failed",
"createNewOrgDescription": "Create a new organization", "createNewOrgDescription": "Create a new organization",
"organization": "Organization", "organization": "Organization",
"primary": "Primary",
"port": "Port", "port": "Port",
"securityKeyManage": "Manage Security Keys", "securityKeyManage": "Manage Security Keys",
"securityKeyDescription": "Add or remove security keys for passwordless authentication", "securityKeyDescription": "Add or remove security keys for passwordless authentication",
@@ -1928,6 +1935,9 @@
"authPageBrandingQuestionRemove": "Are you sure you want to remove the branding for Auth Pages ?", "authPageBrandingQuestionRemove": "Are you sure you want to remove the branding for Auth Pages ?",
"authPageBrandingDeleteConfirm": "Confirm Delete Branding", "authPageBrandingDeleteConfirm": "Confirm Delete Branding",
"brandingLogoURL": "Logo URL", "brandingLogoURL": "Logo URL",
"brandingLogoURLOrPath": "Logo URL or Path",
"brandingLogoPathDescription": "Enter a URL or a local path.",
"brandingLogoURLDescription": "Enter a publicly accessible URL to your logo image.",
"brandingPrimaryColor": "Primary Color", "brandingPrimaryColor": "Primary Color",
"brandingLogoWidth": "Width (px)", "brandingLogoWidth": "Width (px)",
"brandingLogoHeight": "Height (px)", "brandingLogoHeight": "Height (px)",

4127
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -33,8 +33,8 @@
}, },
"dependencies": { "dependencies": {
"@asteasolutions/zod-to-openapi": "8.4.0", "@asteasolutions/zod-to-openapi": "8.4.0",
"@aws-sdk/client-s3": "3.971.0", "@aws-sdk/client-s3": "3.989.0",
"@faker-js/faker": "10.2.0", "@faker-js/faker": "10.3.0",
"@headlessui/react": "2.2.9", "@headlessui/react": "2.2.9",
"@hookform/resolvers": "5.2.2", "@hookform/resolvers": "5.2.2",
"@monaco-editor/react": "4.7.0", "@monaco-editor/react": "4.7.0",
@@ -59,67 +59,66 @@
"@radix-ui/react-tabs": "1.1.13", "@radix-ui/react-tabs": "1.1.13",
"@radix-ui/react-toast": "1.2.15", "@radix-ui/react-toast": "1.2.15",
"@radix-ui/react-tooltip": "1.2.8", "@radix-ui/react-tooltip": "1.2.8",
"@react-email/components": "1.0.2", "@react-email/components": "1.0.7",
"@react-email/render": "2.0.0", "@react-email/render": "2.0.4",
"@react-email/tailwind": "2.0.2", "@react-email/tailwind": "2.0.4",
"@simplewebauthn/browser": "13.2.2", "@simplewebauthn/browser": "13.2.2",
"@simplewebauthn/server": "13.2.2", "@simplewebauthn/server": "13.2.2",
"@tailwindcss/forms": "0.5.11", "@tailwindcss/forms": "0.5.11",
"@tanstack/react-query": "5.90.12", "@tanstack/react-query": "5.90.21",
"@tanstack/react-table": "8.21.3", "@tanstack/react-table": "8.21.3",
"arctic": "3.7.0", "arctic": "3.7.0",
"axios": "1.13.2", "axios": "1.13.5",
"better-sqlite3": "11.9.1", "better-sqlite3": "11.9.1",
"canvas-confetti": "1.9.4", "canvas-confetti": "1.9.4",
"class-variance-authority": "0.7.1", "class-variance-authority": "0.7.1",
"clsx": "2.1.1", "clsx": "2.1.1",
"cmdk": "1.1.1", "cmdk": "1.1.1",
"cookie-parser": "1.4.7", "cookie-parser": "1.4.7",
"cors": "2.8.5", "cors": "2.8.6",
"crypto-js": "4.2.0", "crypto-js": "4.2.0",
"d3": "7.9.0", "d3": "7.9.0",
"date-fns": "4.1.0",
"drizzle-orm": "0.45.1", "drizzle-orm": "0.45.1",
"eslint": "9.39.2",
"eslint-config-next": "16.1.0",
"express": "5.2.1", "express": "5.2.1",
"express-rate-limit": "8.2.1", "express-rate-limit": "8.2.1",
"glob": "13.0.0", "glob": "13.0.3",
"helmet": "8.1.0", "helmet": "8.1.0",
"http-errors": "2.0.1", "http-errors": "2.0.1",
"input-otp": "1.4.2", "input-otp": "1.4.2",
"ioredis": "5.9.2", "ioredis": "5.9.3",
"jmespath": "0.16.0", "jmespath": "0.16.0",
"js-yaml": "4.1.1", "js-yaml": "4.1.1",
"jsonwebtoken": "9.0.3", "jsonwebtoken": "9.0.3",
"lucide-react": "0.562.0", "lucide-react": "0.563.0",
"maxmind": "5.0.1", "maxmind": "5.0.5",
"moment": "2.30.1", "moment": "2.30.1",
"next": "15.5.9", "next": "15.5.12",
"next-intl": "4.7.0", "next-intl": "4.8.2",
"next-themes": "0.4.6", "next-themes": "0.4.6",
"nextjs-toploader": "3.9.17", "nextjs-toploader": "3.9.17",
"node-cache": "5.1.2", "node-cache": "5.1.2",
"nodemailer": "7.0.11", "nodemailer": "8.0.1",
"oslo": "1.2.1", "oslo": "1.2.1",
"pg": "8.17.1", "pg": "8.18.0",
"posthog-node": "5.23.0", "posthog-node": "5.24.15",
"qrcode.react": "4.2.0", "qrcode.react": "4.2.0",
"react": "19.2.3", "react": "19.2.4",
"react-day-picker": "9.13.0", "react-day-picker": "9.13.2",
"react-dom": "19.2.3", "react-dom": "19.2.4",
"react-easy-sort": "1.8.0", "react-easy-sort": "1.8.0",
"react-hook-form": "7.71.1", "react-hook-form": "7.71.1",
"react-icons": "5.5.0", "react-icons": "5.5.0",
"recharts": "2.15.4", "recharts": "2.15.4",
"reodotdev": "1.0.0", "reodotdev": "1.0.0",
"resend": "6.8.0", "resend": "6.9.2",
"semver": "7.7.3", "semver": "7.7.4",
"stripe": "20.2.0", "sshpk": "^1.18.0",
"stripe": "20.3.1",
"swagger-ui-express": "5.0.1", "swagger-ui-express": "5.0.1",
"tailwind-merge": "3.4.0", "tailwind-merge": "3.4.0",
"topojson-client": "3.1.0", "topojson-client": "3.1.0",
"tw-animate-css": "1.4.0", "tw-animate-css": "1.4.0",
"use-debounce": "^10.1.0",
"uuid": "13.0.0", "uuid": "13.0.0",
"vaul": "1.1.2", "vaul": "1.1.2",
"visionscarto-world-atlas": "1.0.0", "visionscarto-world-atlas": "1.0.0",
@@ -128,14 +127,15 @@
"ws": "8.19.0", "ws": "8.19.0",
"yaml": "2.8.2", "yaml": "2.8.2",
"yargs": "18.0.0", "yargs": "18.0.0",
"zod": "4.3.5", "zod": "4.3.6",
"zod-validation-error": "5.0.0" "zod-validation-error": "5.0.0"
}, },
"devDependencies": { "devDependencies": {
"@dotenvx/dotenvx": "1.51.2", "@dotenvx/dotenvx": "1.52.0",
"@esbuild-plugins/tsconfig-paths": "0.1.2", "@esbuild-plugins/tsconfig-paths": "0.1.2",
"@react-email/preview-server": "5.2.8",
"@tailwindcss/postcss": "4.1.18", "@tailwindcss/postcss": "4.1.18",
"@tanstack/react-query-devtools": "5.91.1", "@tanstack/react-query-devtools": "5.91.3",
"@types/better-sqlite3": "7.6.13", "@types/better-sqlite3": "7.6.13",
"@types/cookie-parser": "1.4.10", "@types/cookie-parser": "1.4.10",
"@types/cors": "2.8.19", "@types/cors": "2.8.19",
@@ -144,30 +144,33 @@
"@types/express": "5.0.6", "@types/express": "5.0.6",
"@types/express-session": "1.18.2", "@types/express-session": "1.18.2",
"@types/jmespath": "0.15.2", "@types/jmespath": "0.15.2",
"@types/js-yaml": "4.0.9",
"@types/jsonwebtoken": "9.0.10", "@types/jsonwebtoken": "9.0.10",
"@types/node": "24.10.2", "@types/node": "25.2.3",
"@types/nodemailer": "7.0.4", "@types/nodemailer": "7.0.9",
"@types/nprogress": "0.2.3", "@types/nprogress": "0.2.3",
"@types/pg": "8.16.0", "@types/pg": "8.16.0",
"@types/react": "19.2.7", "@types/react": "19.2.14",
"@types/react-dom": "19.2.3", "@types/react-dom": "19.2.3",
"@types/semver": "7.7.1", "@types/semver": "7.7.1",
"@types/sshpk": "^1.17.4",
"@types/swagger-ui-express": "4.1.8", "@types/swagger-ui-express": "4.1.8",
"@types/topojson-client": "3.1.5", "@types/topojson-client": "3.1.5",
"@types/ws": "8.18.1", "@types/ws": "8.18.1",
"@types/yargs": "17.0.35", "@types/yargs": "17.0.35",
"@types/js-yaml": "4.0.9",
"babel-plugin-react-compiler": "1.0.0", "babel-plugin-react-compiler": "1.0.0",
"drizzle-kit": "0.31.8", "drizzle-kit": "0.31.9",
"esbuild": "0.27.2", "esbuild": "0.27.3",
"esbuild-node-externals": "1.20.1", "esbuild-node-externals": "1.20.1",
"eslint": "9.39.2",
"eslint-config-next": "16.1.6",
"postcss": "8.5.6", "postcss": "8.5.6",
"prettier": "3.8.0", "prettier": "3.8.1",
"react-email": "5.2.5", "react-email": "5.2.8",
"tailwindcss": "4.1.18", "tailwindcss": "4.1.18",
"tsc-alias": "1.8.16", "tsc-alias": "1.8.16",
"tsx": "4.21.0", "tsx": "4.21.0",
"typescript": "5.9.3", "typescript": "5.9.3",
"typescript-eslint": "8.53.1" "typescript-eslint": "8.55.0"
} }
} }

View File

@@ -131,7 +131,8 @@ export enum ActionsEnum {
viewLogs = "viewLogs", viewLogs = "viewLogs",
exportLogs = "exportLogs", exportLogs = "exportLogs",
listApprovals = "listApprovals", listApprovals = "listApprovals",
updateApprovals = "updateApprovals" updateApprovals = "updateApprovals",
signSshKey = "signSshKey"
} }
export async function checkUserActionPermission( export async function checkUserActionPermission(

View File

@@ -0,0 +1,45 @@
import { db } from "@server/db";
import { and, eq } from "drizzle-orm";
import { roleSiteResources, userSiteResources } from "@server/db";
export async function canUserAccessSiteResource({
userId,
resourceId,
roleId
}: {
userId: string;
resourceId: number;
roleId: number;
}): Promise<boolean> {
const roleResourceAccess = await db
.select()
.from(roleSiteResources)
.where(
and(
eq(roleSiteResources.siteResourceId, resourceId),
eq(roleSiteResources.roleId, roleId)
)
)
.limit(1);
if (roleResourceAccess.length > 0) {
return true;
}
const userResourceAccess = await db
.select()
.from(userSiteResources)
.where(
and(
eq(userSiteResources.userId, userId),
eq(userSiteResources.siteResourceId, resourceId)
)
)
.limit(1);
if (userResourceAccess.length > 0) {
return true;
}
return false;
}

View File

@@ -1,18 +1,16 @@
import {
pgTable,
serial,
varchar,
boolean,
integer,
bigint,
real,
text,
index,
uniqueIndex
} from "drizzle-orm/pg-core";
import { InferSelectModel } from "drizzle-orm";
import { randomUUID } from "crypto"; import { randomUUID } from "crypto";
import { alias } from "yargs"; import { InferSelectModel } from "drizzle-orm";
import {
bigint,
boolean,
index,
integer,
pgTable,
real,
serial,
text,
varchar
} from "drizzle-orm/pg-core";
export const domains = pgTable("domains", { export const domains = pgTable("domains", {
domainId: varchar("domainId").primaryKey(), domainId: varchar("domainId").primaryKey(),
@@ -55,7 +53,11 @@ export const orgs = pgTable("orgs", {
.default(0), .default(0),
settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year
.notNull() .notNull()
.default(0) .default(0),
sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format)
sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format)
isBillingOrg: boolean("isBillingOrg"),
billingOrgId: varchar("billingOrgId")
}); });
export const orgDomains = pgTable("orgDomains", { export const orgDomains = pgTable("orgDomains", {
@@ -188,7 +190,9 @@ export const targetHealthCheck = pgTable("targetHealthCheck", {
hcFollowRedirects: boolean("hcFollowRedirects").default(true), hcFollowRedirects: boolean("hcFollowRedirects").default(true),
hcMethod: varchar("hcMethod").default("GET"), hcMethod: varchar("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy" hcHealth: text("hcHealth")
.$type<"unknown" | "healthy" | "unhealthy">()
.default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName") hcTlsServerName: text("hcTlsServerName")
}); });
@@ -218,7 +222,7 @@ export const siteResources = pgTable("siteResources", {
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
niceId: varchar("niceId").notNull(), niceId: varchar("niceId").notNull(),
name: varchar("name").notNull(), name: varchar("name").notNull(),
mode: varchar("mode").notNull(), // "host" | "cidr" | "port" mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
protocol: varchar("protocol"), // only for port mode protocol: varchar("protocol"), // only for port mode
proxyPort: integer("proxyPort"), // only for port mode proxyPort: integer("proxyPort"), // only for port mode
destinationPort: integer("destinationPort"), // only for port mode destinationPort: integer("destinationPort"), // only for port mode
@@ -328,7 +332,8 @@ export const userOrgs = pgTable("userOrgs", {
.notNull() .notNull()
.references(() => roles.roleId), .references(() => roles.roleId),
isOwner: boolean("isOwner").notNull().default(false), isOwner: boolean("isOwner").notNull().default(false),
autoProvisioned: boolean("autoProvisioned").default(false) autoProvisioned: boolean("autoProvisioned").default(false),
pamUsername: varchar("pamUsername") // cleaned username for ssh and such
}); });
export const emailVerificationCodes = pgTable("emailVerificationCodes", { export const emailVerificationCodes = pgTable("emailVerificationCodes", {
@@ -984,6 +989,16 @@ export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", {
}) })
}); });
export const roundTripMessageTracker = pgTable("roundTripMessageTracker", {
messageId: serial("messageId").primaryKey(),
wsClientId: varchar("clientId"),
messageType: varchar("messageType"),
sentAt: bigint("sentAt", { mode: "number" }).notNull(),
receivedAt: bigint("receivedAt", { mode: "number" }),
error: text("error"),
complete: boolean("complete").notNull().default(false)
});
export type Org = InferSelectModel<typeof orgs>; export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>; export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>; export type Site = InferSelectModel<typeof sites>;
@@ -1044,3 +1059,4 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>; export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>; export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>; export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type RoundTripMessageTracker = InferSelectModel<typeof roundTripMessageTracker>;

View File

@@ -1,13 +1,6 @@
import { randomUUID } from "crypto"; import { randomUUID } from "crypto";
import { InferSelectModel } from "drizzle-orm"; import { InferSelectModel } from "drizzle-orm";
import { import { index, integer, sqliteTable, text } from "drizzle-orm/sqlite-core";
sqliteTable,
text,
integer,
index,
uniqueIndex
} from "drizzle-orm/sqlite-core";
import { no } from "zod/v4/locales";
export const domains = sqliteTable("domains", { export const domains = sqliteTable("domains", {
domainId: text("domainId").primaryKey(), domainId: text("domainId").primaryKey(),
@@ -52,7 +45,11 @@ export const orgs = sqliteTable("orgs", {
.default(0), .default(0),
settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year
.notNull() .notNull()
.default(0) .default(0),
sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format)
sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format)
isBillingOrg: integer("isBillingOrg", { mode: "boolean" }),
billingOrgId: text("billingOrgId")
}); });
export const userDomains = sqliteTable("userDomains", { export const userDomains = sqliteTable("userDomains", {
@@ -214,7 +211,9 @@ export const targetHealthCheck = sqliteTable("targetHealthCheck", {
}).default(true), }).default(true),
hcMethod: text("hcMethod").default("GET"), hcMethod: text("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown"), // "unknown", "healthy", "unhealthy" hcHealth: text("hcHealth")
.$type<"unknown" | "healthy" | "unhealthy">()
.default("unknown"), // "unknown", "healthy", "unhealthy"
hcTlsServerName: text("hcTlsServerName") hcTlsServerName: text("hcTlsServerName")
}); });
@@ -246,7 +245,7 @@ export const siteResources = sqliteTable("siteResources", {
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
niceId: text("niceId").notNull(), niceId: text("niceId").notNull(),
name: text("name").notNull(), name: text("name").notNull(),
mode: text("mode").notNull(), // "host" | "cidr" | "port" mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
protocol: text("protocol"), // only for port mode protocol: text("protocol"), // only for port mode
proxyPort: integer("proxyPort"), // only for port mode proxyPort: integer("proxyPort"), // only for port mode
destinationPort: integer("destinationPort"), // only for port mode destinationPort: integer("destinationPort"), // only for port mode
@@ -638,7 +637,8 @@ export const userOrgs = sqliteTable("userOrgs", {
isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false), isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false),
autoProvisioned: integer("autoProvisioned", { autoProvisioned: integer("autoProvisioned", {
mode: "boolean" mode: "boolean"
}).default(false) }).default(false),
pamUsername: text("pamUsername") // cleaned username for ssh and such
}); });
export const emailVerificationCodes = sqliteTable("emailVerificationCodes", { export const emailVerificationCodes = sqliteTable("emailVerificationCodes", {
@@ -1080,6 +1080,16 @@ export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", {
}) })
}); });
export const roundTripMessageTracker = sqliteTable("roundTripMessageTracker", {
messageId: integer("messageId").primaryKey({ autoIncrement: true }),
wsClientId: text("clientId"),
messageType: text("messageType"),
sentAt: integer("sentAt").notNull(),
receivedAt: integer("receivedAt"),
error: text("error"),
complete: integer("complete", { mode: "boolean" }).notNull().default(false)
});
export type Org = InferSelectModel<typeof orgs>; export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>; export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>; export type Site = InferSelectModel<typeof sites>;
@@ -1141,3 +1151,6 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>; export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>; export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>; export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker
>;

View File

@@ -4,6 +4,7 @@ export enum FeatureId {
EGRESS_DATA_MB = "egressDataMb", EGRESS_DATA_MB = "egressDataMb",
DOMAINS = "domains", DOMAINS = "domains",
REMOTE_EXIT_NODES = "remoteExitNodes", REMOTE_EXIT_NODES = "remoteExitNodes",
ORGINIZATIONS = "organizations",
TIER1 = "tier1" TIER1 = "tier1"
} }
@@ -19,6 +20,8 @@ export async function getFeatureDisplayName(featureId: FeatureId): Promise<strin
return "Domains"; return "Domains";
case FeatureId.REMOTE_EXIT_NODES: case FeatureId.REMOTE_EXIT_NODES:
return "Remote Exit Nodes"; return "Remote Exit Nodes";
case FeatureId.ORGINIZATIONS:
return "Organizations";
case FeatureId.TIER1: case FeatureId.TIER1:
return "Home Lab"; return "Home Lab";
default: default:

View File

@@ -7,18 +7,12 @@ export type LimitSet = Partial<{
}; };
}>; }>;
export const sandboxLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.SITES]: { value: 1, description: "Sandbox limit" },
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" },
};
export const freeLimitSet: LimitSet = { export const freeLimitSet: LimitSet = {
[FeatureId.SITES]: { value: 5, description: "Basic limit" }, [FeatureId.SITES]: { value: 5, description: "Basic limit" },
[FeatureId.USERS]: { value: 5, description: "Basic limit" }, [FeatureId.USERS]: { value: 5, description: "Basic limit" },
[FeatureId.DOMAINS]: { value: 5, description: "Basic limit" }, [FeatureId.DOMAINS]: { value: 5, description: "Basic limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Basic limit" },
}; };
export const tier1LimitSet: LimitSet = { export const tier1LimitSet: LimitSet = {
@@ -26,6 +20,7 @@ export const tier1LimitSet: LimitSet = {
[FeatureId.SITES]: { value: 10, description: "Home limit" }, [FeatureId.SITES]: { value: 10, description: "Home limit" },
[FeatureId.DOMAINS]: { value: 10, description: "Home limit" }, [FeatureId.DOMAINS]: { value: 10, description: "Home limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Home limit" },
}; };
export const tier2LimitSet: LimitSet = { export const tier2LimitSet: LimitSet = {
@@ -45,6 +40,10 @@ export const tier2LimitSet: LimitSet = {
value: 3, value: 3,
description: "Team limit" description: "Team limit"
}, },
[FeatureId.ORGINIZATIONS]: {
value: 1,
description: "Team limit"
}
}; };
export const tier3LimitSet: LimitSet = { export const tier3LimitSet: LimitSet = {
@@ -64,4 +63,8 @@ export const tier3LimitSet: LimitSet = {
value: 20, value: 20,
description: "Business limit" description: "Business limit"
}, },
[FeatureId.ORGINIZATIONS]: {
value: 5,
description: "Business limit"
},
}; };

View File

@@ -14,7 +14,8 @@ export enum TierFeature {
TwoFactorEnforcement = "twoFactorEnforcement", // handle downgrade by setting to optional TwoFactorEnforcement = "twoFactorEnforcement", // handle downgrade by setting to optional
SessionDurationPolicies = "sessionDurationPolicies", // handle downgrade by setting to default duration SessionDurationPolicies = "sessionDurationPolicies", // handle downgrade by setting to default duration
PasswordExpirationPolicies = "passwordExpirationPolicies", // handle downgrade by setting to default duration PasswordExpirationPolicies = "passwordExpirationPolicies", // handle downgrade by setting to default duration
AutoProvisioning = "autoProvisioning" // handle downgrade by disabling auto provisioning AutoProvisioning = "autoProvisioning", // handle downgrade by disabling auto provisioning
SshPam = "sshPam"
} }
export const tierMatrix: Record<TierFeature, Tier[]> = { export const tierMatrix: Record<TierFeature, Tier[]> = {
@@ -46,5 +47,6 @@ export const tierMatrix: Record<TierFeature, Tier[]> = {
"tier3", "tier3",
"enterprise" "enterprise"
], ],
[TierFeature.AutoProvisioning]: ["tier1", "tier3", "enterprise"] [TierFeature.AutoProvisioning]: ["tier1", "tier3", "enterprise"],
[TierFeature.SshPam]: ["enterprise"]
}; };

View File

@@ -1,34 +1,19 @@
import { eq, sql, and } from "drizzle-orm"; import { eq, sql, and } from "drizzle-orm";
import { v4 as uuidv4 } from "uuid";
import { PutObjectCommand } from "@aws-sdk/client-s3";
import { import {
db, db,
usage, usage,
customers, customers,
sites,
newts,
limits, limits,
Usage, Usage,
Limit, Limit,
Transaction Transaction,
orgs
} from "@server/db"; } from "@server/db";
import { FeatureId, getFeatureMeterId } from "./features"; import { FeatureId, getFeatureMeterId } from "./features";
import logger from "@server/logger"; import logger from "@server/logger";
import { sendToClient } from "#dynamic/routers/ws";
import { build } from "@server/build"; import { build } from "@server/build";
import { s3Client } from "@server/lib/s3";
import cache from "@server/lib/cache"; import cache from "@server/lib/cache";
interface StripeEvent {
identifier?: string;
timestamp: number;
event_name: string;
payload: {
value: number;
stripe_customer_id: string;
};
}
export function noop() { export function noop() {
if (build !== "saas") { if (build !== "saas") {
return true; return true;
@@ -37,41 +22,11 @@ export function noop() {
} }
export class UsageService { export class UsageService {
private bucketName: string | undefined;
private events: StripeEvent[] = [];
private lastUploadTime: number = Date.now();
private isUploading: boolean = false;
constructor() { constructor() {
if (noop()) { if (noop()) {
return; return;
} }
// this.bucketName = process.env.S3_BUCKET || undefined;
// // Periodically check and upload events
// setInterval(() => {
// this.checkAndUploadEvents().catch((err) => {
// logger.error("Error in periodic event upload:", err);
// });
// }, 30000); // every 30 seconds
// // Handle graceful shutdown on SIGTERM
// process.on("SIGTERM", async () => {
// logger.info(
// "SIGTERM received, uploading events before shutdown..."
// );
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// });
// // Handle SIGINT as well (Ctrl+C)
// process.on("SIGINT", async () => {
// logger.info("SIGINT received, uploading events before shutdown...");
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// process.exit(0);
// });
} }
/** /**
@@ -91,6 +46,8 @@ export class UsageService {
return null; return null;
} }
let orgIdToUse = await this.getBillingOrg(orgId, transaction);
// Truncate value to 11 decimal places // Truncate value to 11 decimal places
value = this.truncateValue(value); value = this.truncateValue(value);
@@ -100,20 +57,10 @@ export class UsageService {
while (attempt <= maxRetries) { while (attempt <= maxRetries) {
try { try {
// Get subscription data for this org (with caching)
const customerId = await this.getCustomerId(orgId, featureId);
if (!customerId) {
logger.warn(
`No subscription data found for org ${orgId} and feature ${featureId}`
);
return null;
}
let usage; let usage;
if (transaction) { if (transaction) {
usage = await this.internalAddUsage( usage = await this.internalAddUsage(
orgId, orgIdToUse,
featureId, featureId,
value, value,
transaction transaction
@@ -121,7 +68,7 @@ export class UsageService {
} else { } else {
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
usage = await this.internalAddUsage( usage = await this.internalAddUsage(
orgId, orgIdToUse,
featureId, featureId,
value, value,
trx trx
@@ -129,11 +76,6 @@ export class UsageService {
}); });
} }
// Log event for Stripe
// if (privateConfig.getRawPrivateConfig().flags.usage_reporting) {
// await this.logStripeEvent(featureId, value, customerId);
// }
return usage || null; return usage || null;
} catch (error: any) { } catch (error: any) {
// Check if this is a deadlock error // Check if this is a deadlock error
@@ -150,7 +92,7 @@ export class UsageService {
const delay = baseDelay + jitter; const delay = baseDelay + jitter;
logger.warn( logger.warn(
`Deadlock detected for ${orgId}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms` `Deadlock detected for ${orgIdToUse}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms`
); );
await new Promise((resolve) => setTimeout(resolve, delay)); await new Promise((resolve) => setTimeout(resolve, delay));
@@ -158,7 +100,7 @@ export class UsageService {
} }
logger.error( logger.error(
`Failed to add usage for ${orgId}/${featureId} after ${attempt} attempts:`, `Failed to add usage for ${orgIdToUse}/${featureId} after ${attempt} attempts:`,
error error
); );
break; break;
@@ -169,7 +111,7 @@ export class UsageService {
} }
private async internalAddUsage( private async internalAddUsage(
orgId: string, orgId: string, // here the orgId is the billing org already resolved by getBillingOrg in updateCount
featureId: FeatureId, featureId: FeatureId,
value: number, value: number,
trx: Transaction trx: Transaction
@@ -188,17 +130,22 @@ export class UsageService {
featureId, featureId,
orgId, orgId,
meterId, meterId,
latestValue: value, instantaneousValue: value || 0,
latestValue: value || 0,
updatedAt: Math.floor(Date.now() / 1000) updatedAt: Math.floor(Date.now() / 1000)
}) })
.onConflictDoUpdate({ .onConflictDoUpdate({
target: usage.usageId, target: usage.usageId,
set: { set: {
latestValue: sql`${usage.latestValue} + ${value}` instantaneousValue: sql`COALESCE(${usage.instantaneousValue}, 0) + ${value}`
} }
}) })
.returning(); .returning();
logger.debug(
`Added usage for org ${orgId} feature ${featureId}: +${value}, new instantaneousValue: ${returnUsage.instantaneousValue}`
);
return returnUsage; return returnUsage;
} }
@@ -221,18 +168,10 @@ export class UsageService {
if (noop()) { if (noop()) {
return; return;
} }
try {
if (!customerId) {
customerId =
(await this.getCustomerId(orgId, featureId)) || undefined;
if (!customerId) {
logger.warn(
`No subscription data found for org ${orgId} and feature ${featureId}`
);
return;
}
}
let orgIdToUse = await this.getBillingOrg(orgId);
try {
// Truncate value to 11 decimal places if provided // Truncate value to 11 decimal places if provided
if (value !== undefined && value !== null) { if (value !== undefined && value !== null) {
value = this.truncateValue(value); value = this.truncateValue(value);
@@ -242,7 +181,7 @@ export class UsageService {
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Get existing meter record // Get existing meter record
const usageId = `${orgId}-${featureId}`; const usageId = `${orgIdToUse}-${featureId}`;
// Get current usage record // Get current usage record
[currentUsage] = await trx [currentUsage] = await trx
.select() .select()
@@ -264,7 +203,7 @@ export class UsageService {
await trx.insert(usage).values({ await trx.insert(usage).values({
usageId, usageId,
featureId, featureId,
orgId, orgId: orgIdToUse,
meterId, meterId,
instantaneousValue: value || 0, instantaneousValue: value || 0,
latestValue: value || 0, latestValue: value || 0,
@@ -278,7 +217,7 @@ export class UsageService {
// } // }
} catch (error) { } catch (error) {
logger.error( logger.error(
`Failed to update count usage for ${orgId}/${featureId}:`, `Failed to update count usage for ${orgIdToUse}/${featureId}:`,
error error
); );
} }
@@ -288,7 +227,9 @@ export class UsageService {
orgId: string, orgId: string,
featureId: FeatureId featureId: FeatureId
): Promise<string | null> { ): Promise<string | null> {
const cacheKey = `customer_${orgId}_${featureId}`; let orgIdToUse = await this.getBillingOrg(orgId);
const cacheKey = `customer_${orgIdToUse}_${featureId}`;
const cached = cache.get<string>(cacheKey); const cached = cache.get<string>(cacheKey);
if (cached) { if (cached) {
@@ -302,7 +243,7 @@ export class UsageService {
customerId: customers.customerId customerId: customers.customerId
}) })
.from(customers) .from(customers)
.where(eq(customers.orgId, orgId)) .where(eq(customers.orgId, orgIdToUse))
.limit(1); .limit(1);
if (!customer) { if (!customer) {
@@ -317,112 +258,13 @@ export class UsageService {
return customerId; return customerId;
} catch (error) { } catch (error) {
logger.error( logger.error(
`Failed to get subscription data for ${orgId}/${featureId}:`, `Failed to get subscription data for ${orgIdToUse}/${featureId}:`,
error error
); );
return null; return null;
} }
} }
private async logStripeEvent(
featureId: FeatureId,
value: number,
customerId: string
): Promise<void> {
// Truncate value to 11 decimal places before sending to Stripe
const truncatedValue = this.truncateValue(value);
const event: StripeEvent = {
identifier: uuidv4(),
timestamp: Math.floor(new Date().getTime() / 1000),
event_name: featureId,
payload: {
value: truncatedValue,
stripe_customer_id: customerId
}
};
this.addEventToMemory(event);
await this.checkAndUploadEvents();
}
private addEventToMemory(event: StripeEvent): void {
if (!this.bucketName) {
logger.warn(
"S3 bucket name is not configured, skipping event storage."
);
return;
}
this.events.push(event);
}
private async checkAndUploadEvents(): Promise<void> {
const now = Date.now();
const timeSinceLastUpload = now - this.lastUploadTime;
// Check if at least 1 minute has passed since last upload
if (timeSinceLastUpload >= 60000 && this.events.length > 0) {
await this.uploadEventsToS3();
}
}
private async uploadEventsToS3(): Promise<void> {
if (!this.bucketName) {
logger.warn(
"S3 bucket name is not configured, skipping S3 upload."
);
return;
}
if (this.events.length === 0) {
return;
}
// Check if already uploading
if (this.isUploading) {
logger.debug("Already uploading events, skipping");
return;
}
this.isUploading = true;
try {
// Take a snapshot of current events and clear the array
const eventsToUpload = [...this.events];
this.events = [];
this.lastUploadTime = Date.now();
const fileName = this.generateEventFileName();
const fileContent = JSON.stringify(eventsToUpload, null, 2);
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: fileName,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
logger.info(
`Uploaded ${fileName} to S3 with ${eventsToUpload.length} events`
);
} catch (error) {
logger.error("Failed to upload events to S3:", error);
// Note: Events are lost if upload fails. In a production system,
// you might want to add the events back to the array or implement retry logic
} finally {
this.isUploading = false;
}
}
private generateEventFileName(): string {
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
const uuid = uuidv4().substring(0, 8);
return `events-${timestamp}-${uuid}.json`;
}
public async getUsage( public async getUsage(
orgId: string, orgId: string,
featureId: FeatureId, featureId: FeatureId,
@@ -432,7 +274,9 @@ export class UsageService {
return null; return null;
} }
const usageId = `${orgId}-${featureId}`; let orgIdToUse = await this.getBillingOrg(orgId, trx);
const usageId = `${orgIdToUse}-${featureId}`;
try { try {
const [result] = await trx const [result] = await trx
@@ -444,7 +288,7 @@ export class UsageService {
if (!result) { if (!result) {
// Lets create one if it doesn't exist using upsert to handle race conditions // Lets create one if it doesn't exist using upsert to handle race conditions
logger.info( logger.info(
`Creating new usage record for ${orgId}/${featureId}` `Creating new usage record for ${orgIdToUse}/${featureId}`
); );
const meterId = getFeatureMeterId(featureId); const meterId = getFeatureMeterId(featureId);
@@ -454,7 +298,7 @@ export class UsageService {
.values({ .values({
usageId, usageId,
featureId, featureId,
orgId, orgId: orgIdToUse,
meterId, meterId,
latestValue: 0, latestValue: 0,
updatedAt: Math.floor(Date.now() / 1000) updatedAt: Math.floor(Date.now() / 1000)
@@ -476,7 +320,7 @@ export class UsageService {
} catch (insertError) { } catch (insertError) {
// Fallback: try to fetch existing record in case of any insert issues // Fallback: try to fetch existing record in case of any insert issues
logger.warn( logger.warn(
`Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`, `Insert failed for ${orgIdToUse}/${featureId}, attempting to fetch existing record:`,
insertError insertError
); );
const [existingUsage] = await trx const [existingUsage] = await trx
@@ -491,19 +335,41 @@ export class UsageService {
return result; return result;
} catch (error) { } catch (error) {
logger.error( logger.error(
`Failed to get usage for ${orgId}/${featureId}:`, `Failed to get usage for ${orgIdToUse}/${featureId}:`,
error error
); );
throw error; throw error;
} }
} }
public async forceUpload(): Promise<void> { public async getBillingOrg(
if (this.events.length > 0) { orgId: string,
// Force upload regardless of time trx: Transaction | typeof db = db
this.lastUploadTime = 0; // Reset to force upload ): Promise<string> {
await this.uploadEventsToS3(); let orgIdToUse = orgId;
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error(`Organization with ID ${orgId} not found`);
} }
if (!org.isBillingOrg) {
if (org.billingOrgId) {
orgIdToUse = org.billingOrgId;
} else {
throw new Error(
`Organization ${orgId} is not a billing org and does not have a billingOrgId set`
);
}
}
return orgIdToUse;
} }
public async checkLimitSet( public async checkLimitSet(
@@ -515,6 +381,9 @@ export class UsageService {
if (noop()) { if (noop()) {
return false; return false;
} }
let orgIdToUse = await this.getBillingOrg(orgId, trx);
// This method should check the current usage against the limits set for the organization // This method should check the current usage against the limits set for the organization
// and kick out all of the sites on the org // and kick out all of the sites on the org
let hasExceededLimits = false; let hasExceededLimits = false;
@@ -528,7 +397,7 @@ export class UsageService {
.from(limits) .from(limits)
.where( .where(
and( and(
eq(limits.orgId, orgId), eq(limits.orgId, orgIdToUse),
eq(limits.featureId, featureId) eq(limits.featureId, featureId)
) )
); );
@@ -537,11 +406,11 @@ export class UsageService {
orgLimits = await trx orgLimits = await trx
.select() .select()
.from(limits) .from(limits)
.where(eq(limits.orgId, orgId)); .where(eq(limits.orgId, orgIdToUse));
} }
if (orgLimits.length === 0) { if (orgLimits.length === 0) {
logger.debug(`No limits set for org ${orgId}`); logger.debug(`No limits set for org ${orgIdToUse}`);
return false; return false;
} }
@@ -552,7 +421,7 @@ export class UsageService {
currentUsage = usage; currentUsage = usage;
} else { } else {
currentUsage = await this.getUsage( currentUsage = await this.getUsage(
orgId, orgIdToUse,
limit.featureId as FeatureId, limit.featureId as FeatureId,
trx trx
); );
@@ -563,10 +432,10 @@ export class UsageService {
currentUsage?.latestValue || currentUsage?.latestValue ||
0; 0;
logger.debug( logger.debug(
`Current usage for org ${orgId} on feature ${limit.featureId}: ${usageValue}` `Current usage for org ${orgIdToUse} on feature ${limit.featureId}: ${usageValue}`
); );
logger.debug( logger.debug(
`Limit for org ${orgId} on feature ${limit.featureId}: ${limit.value}` `Limit for org ${orgIdToUse} on feature ${limit.featureId}: ${limit.value}`
); );
if ( if (
currentUsage && currentUsage &&
@@ -574,7 +443,7 @@ export class UsageService {
usageValue > limit.value usageValue > limit.value
) { ) {
logger.debug( logger.debug(
`Org ${orgId} has exceeded limit for ${limit.featureId}: ` + `Org ${orgIdToUse} has exceeded limit for ${limit.featureId}: ` +
`${usageValue} > ${limit.value}` `${usageValue} > ${limit.value}`
); );
hasExceededLimits = true; hasExceededLimits = true;
@@ -582,7 +451,7 @@ export class UsageService {
} }
} }
} catch (error) { } catch (error) {
logger.error(`Error checking limits for org ${orgId}:`, error); logger.error(`Error checking limits for org ${orgIdToUse}:`, error);
} }
return hasExceededLimits; return hasExceededLimits;

View File

@@ -1,197 +0,0 @@
import { isValidCIDR } from "@server/lib/validators";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import {
actions,
apiKeyOrg,
apiKeys,
db,
domains,
Org,
orgDomains,
orgs,
roleActions,
roles,
userOrgs
} from "@server/db";
import { eq } from "drizzle-orm";
import { defaultRoleAllowedActions } from "@server/routers/role";
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import config from "@server/lib/config";
export async function createUserAccountOrg(
userId: string,
userEmail: string
): Promise<{
success: boolean;
org?: {
orgId: string;
name: string;
subnet: string;
};
error?: string;
}> {
// const subnet = await getNextAvailableOrgSubnet();
const orgId = "org_" + userId;
const name = `${userEmail}'s Organization`;
// if (!isValidCIDR(subnet)) {
// return {
// success: false,
// error: "Invalid subnet format. Please provide a valid CIDR notation."
// };
// }
// // make sure the subnet is unique
// const subnetExists = await db
// .select()
// .from(orgs)
// .where(eq(orgs.subnet, subnet))
// .limit(1);
// if (subnetExists.length > 0) {
// return { success: false, error: `Subnet ${subnet} already exists` };
// }
// make sure the orgId is unique
const orgExists = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (orgExists.length > 0) {
return {
success: false,
error: `Organization with ID ${orgId} already exists`
};
}
let error = "";
let org: Org | null = null;
await db.transaction(async (trx) => {
const allDomains = await trx
.select()
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
const newOrg = await trx
.insert(orgs)
.values({
orgId,
name,
// subnet
subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
utilitySubnet: utilitySubnet,
createdAt: new Date().toISOString()
})
.returning();
if (newOrg.length === 0) {
error = "Failed to create organization";
trx.rollback();
return;
}
org = newOrg[0];
// Create admin role within the same transaction
const [insertedRole] = await trx
.insert(roles)
.values({
orgId: newOrg[0].orgId,
isAdmin: true,
name: "Admin",
description: "Admin role with the most permissions"
})
.returning({ roleId: roles.roleId });
if (!insertedRole || !insertedRole.roleId) {
error = "Failed to create Admin role";
trx.rollback();
return;
}
const roleId = insertedRole.roleId;
// Get all actions and create role actions
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length > 0) {
await trx.insert(roleActions).values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
}
if (allDomains.length) {
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
}
await trx.insert(userOrgs).values({
userId,
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
const memberRole = await trx
.insert(roles)
.values({
name: "Member",
description: "Members can only view resources",
orgId
})
.returning();
await trx.insert(roleActions).values(
defaultRoleAllowedActions.map((action) => ({
roleId: memberRole[0].roleId,
actionId: action,
orgId
}))
);
});
await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet);
if (!org) {
return { success: false, error: "Failed to create org" };
}
if (error) {
return {
success: false,
error: `Failed to create org: ${error}`
};
}
// make sure we have the stripe customer
const customerId = await createCustomer(orgId, userEmail);
if (customerId) {
await usageService.updateCount(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org
}
return {
org: {
orgId,
name,
// subnet
subnet: "100.90.128.0/24"
},
success: true
};
}

View File

@@ -4,14 +4,18 @@ import {
clientSitesAssociationsCache, clientSitesAssociationsCache,
db, db,
domains, domains,
exitNodeOrgs,
exitNodes,
olms, olms,
orgDomains, orgDomains,
orgs, orgs,
remoteExitNodes,
resources, resources,
sites sites,
userOrgs
} from "@server/db"; } from "@server/db";
import { newts, newtSessions } from "@server/db"; import { newts, newtSessions } from "@server/db";
import { eq, and, inArray, sql } from "drizzle-orm"; import { eq, and, inArray, sql, count, countDistinct } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
@@ -19,6 +23,8 @@ import { sendToClient } from "#dynamic/routers/ws";
import { deletePeer } from "@server/routers/gerbil/peers"; import { deletePeer } from "@server/routers/gerbil/peers";
import { OlmErrorCodes } from "@server/routers/olm/error"; import { OlmErrorCodes } from "@server/routers/olm/error";
import { sendTerminateClient } from "@server/routers/client/terminate"; import { sendTerminateClient } from "@server/routers/client/terminate";
import { usageService } from "./billing/usageService";
import { FeatureId } from "./billing";
export type DeleteOrgByIdResult = { export type DeleteOrgByIdResult = {
deletedNewtIds: string[]; deletedNewtIds: string[];
@@ -60,6 +66,11 @@ export async function deleteOrgById(
const deletedNewtIds: string[] = []; const deletedNewtIds: string[] = [];
const olmsToTerminate: string[] = []; const olmsToTerminate: string[] = [];
let domainCount: number | null = null;
let siteCount: number | null = null;
let userCount: number | null = null;
let remoteExitNodeCount: number | null = null;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
for (const site of orgSites) { for (const site of orgSites) {
if (site.pubKey) { if (site.pubKey) {
@@ -74,9 +85,7 @@ export async function deleteOrgById(
deletedNewtIds.push(deletedNewt.newtId); deletedNewtIds.push(deletedNewt.newtId);
await trx await trx
.delete(newtSessions) .delete(newtSessions)
.where( .where(eq(newtSessions.newtId, deletedNewt.newtId));
eq(newtSessions.newtId, deletedNewt.newtId)
);
} }
} }
} }
@@ -137,9 +146,74 @@ export async function deleteOrgById(
.where(inArray(domains.domainId, domainIdsToDelete)); .where(inArray(domains.domainId, domainIdsToDelete));
} }
await trx.delete(resources).where(eq(resources.orgId, orgId)); await trx.delete(resources).where(eq(resources.orgId, orgId));
await usageService.add(orgId, FeatureId.ORGINIZATIONS, -1, trx); // here we are decreasing the org count BEFORE deleting the org because we need to still be able to get the org to get the billing org inside of here
await trx.delete(orgs).where(eq(orgs.orgId, orgId)); await trx.delete(orgs).where(eq(orgs.orgId, orgId));
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (billingOrgs.length > 0) {
const billingOrgIds = billingOrgs.map((org) => org.orgId);
const [domainCountRes] = await trx
.select({ count: count() })
.from(orgDomains)
.where(inArray(orgDomains.orgId, billingOrgIds));
domainCount = domainCountRes.count;
const [siteCountRes] = await trx
.select({ count: count() })
.from(sites)
.where(inArray(sites.orgId, billingOrgIds));
siteCount = siteCountRes.count;
const [userCountRes] = await trx
.select({ count: countDistinct(userOrgs.userId) })
.from(userOrgs)
.where(inArray(userOrgs.orgId, billingOrgIds));
userCount = userCountRes.count;
const [remoteExitNodeCountRes] = await trx
.select({ count: countDistinct(exitNodeOrgs.exitNodeId) })
.from(exitNodeOrgs)
.where(inArray(exitNodeOrgs.orgId, billingOrgIds));
remoteExitNodeCount = remoteExitNodeCountRes.count;
}
}
}); });
if (org.billingOrgId) {
usageService.updateCount(
org.billingOrgId,
FeatureId.DOMAINS,
domainCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.SITES,
siteCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.USERS,
userCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.REMOTE_EXIT_NODES,
remoteExitNodeCount ?? 0
);
}
return { deletedNewtIds, olmsToTerminate }; return { deletedNewtIds, olmsToTerminate };
} }
@@ -155,15 +229,13 @@ export function sendTerminationMessages(result: DeleteOrgByIdResult): void {
); );
} }
for (const olmId of result.olmsToTerminate) { for (const olmId of result.olmsToTerminate) {
sendTerminateClient( sendTerminateClient(0, OlmErrorCodes.TERMINATED_REKEYED, olmId).catch(
0, (error) => {
OlmErrorCodes.TERMINATED_REKEYED,
olmId
).catch((error) => {
logger.error( logger.error(
"Failed to send termination message to olm:", "Failed to send termination message to olm:",
error error
); );
}); }
);
} }
} }

142
server/lib/userOrg.ts Normal file
View File

@@ -0,0 +1,142 @@
import {
db,
Org,
orgs,
resources,
siteResources,
sites,
Transaction,
UserOrg,
userOrgs,
userResources,
userSiteResources,
userSites
} from "@server/db";
import { eq, and, inArray, ne, exists } from "drizzle-orm";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
export async function assignUserToOrg(
org: Org,
values: typeof userOrgs.$inferInsert,
trx: Transaction | typeof db = db
) {
const [userOrg] = await trx.insert(userOrgs).values(values).returning();
// calculate if the user is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, org.orgId)
)
);
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userOrg.userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(org.orgId, FeatureId.USERS, 1, trx);
}
}
}
export async function removeUserFromOrg(
org: Org,
userId: string,
trx: Transaction | typeof db = db
) {
await trx
.delete(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, org.orgId)));
await trx.delete(userResources).where(
and(
eq(userResources.userId, userId),
exists(
trx
.select()
.from(resources)
.where(
and(
eq(resources.resourceId, userResources.resourceId),
eq(resources.orgId, org.orgId)
)
)
)
)
);
await trx.delete(userSiteResources).where(
and(
eq(userSiteResources.userId, userId),
exists(
trx
.select()
.from(siteResources)
.where(
and(
eq(
siteResources.siteResourceId,
userSiteResources.siteResourceId
),
eq(siteResources.orgId, org.orgId)
)
)
)
)
);
await trx.delete(userSites).where(
and(
eq(userSites.userId, userId),
exists(
db
.select()
.from(sites)
.where(
and(
eq(sites.siteId, userSites.siteId),
eq(sites.orgId, org.orgId)
)
)
)
)
);
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
const billingOrgIds = billingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(org.orgId, FeatureId.USERS, -1, trx);
}
}
}

View File

@@ -16,5 +16,6 @@ export enum OpenAPITags {
Client = "Client", Client = "Client",
ApiKey = "API Key", ApiKey = "API Key",
Domain = "Domain", Domain = "Domain",
Blueprint = "Blueprint" Blueprint = "Blueprint",
Ssh = "SSH"
} }

View File

@@ -12,7 +12,8 @@
*/ */
import { build } from "@server/build"; import { build } from "@server/build";
import { db, customers, subscriptions } from "@server/db"; import { db, customers, subscriptions, orgs } from "@server/db";
import logger from "@server/logger";
import { Tier } from "@server/types/Tiers"; import { Tier } from "@server/types/Tiers";
import { eq, and, ne } from "drizzle-orm"; import { eq, and, ne } from "drizzle-orm";
@@ -27,14 +28,38 @@ export async function getOrgTierData(
} }
try { try {
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return { tier, active };
}
let orgIdToUse = org.orgId;
if (!org.isBillingOrg) {
if (!org.billingOrgId) {
logger.warn(
`Org ${orgId} is not a billing org and does not have a billingOrgId`
);
return { tier, active };
}
orgIdToUse = org.billingOrgId;
}
// Get customer for org // Get customer for org
const [customer] = await db const [customer] = await db
.select() .select()
.from(customers) .from(customers)
.where(eq(customers.orgId, orgId)) .where(eq(customers.orgId, orgIdToUse))
.limit(1); .limit(1);
if (customer) { if (!customer) {
return { tier, active };
}
// Query for active subscriptions that are not license type // Query for active subscriptions that are not license type
const [subscription] = await db const [subscription] = await db
.select() .select()
@@ -59,7 +84,6 @@ export async function getOrgTierData(
active = true; active = true;
} }
} }
}
} catch (error) { } catch (error) {
// If org not found or error occurs, return null tier and inactive // If org not found or error occurs, return null tier and inactive
// This is acceptable behavior as per the function signature // This is acceptable behavior as per the function signature

442
server/private/lib/sshCA.ts Normal file
View File

@@ -0,0 +1,442 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import * as crypto from "crypto";
/**
* SSH CA "Server" - Pure TypeScript Implementation
*
* This module provides basic SSH Certificate Authority functionality using
* only Node.js built-in crypto module. No external dependencies or subprocesses.
*
* Usage:
* 1. generateCA() - Creates a new CA key pair, returns CA info including the
* TrustedUserCAKeys line to add to servers
* 2. signPublicKey() - Signs a user's public key with the CA, returns a certificate
*/
// ============================================================================
// SSH Wire Format Helpers
// ============================================================================
/**
* Encode a string in SSH wire format (4-byte length prefix + data)
*/
function encodeString(data: Buffer | string): Buffer {
const buf = typeof data === "string" ? Buffer.from(data, "utf8") : data;
const len = Buffer.alloc(4);
len.writeUInt32BE(buf.length, 0);
return Buffer.concat([len, buf]);
}
/**
* Encode a uint32 in SSH wire format (big-endian)
*/
function encodeUInt32(value: number): Buffer {
const buf = Buffer.alloc(4);
buf.writeUInt32BE(value, 0);
return buf;
}
/**
* Encode a uint64 in SSH wire format (big-endian)
*/
function encodeUInt64(value: bigint): Buffer {
const buf = Buffer.alloc(8);
buf.writeBigUInt64BE(value, 0);
return buf;
}
/**
* Decode a string from SSH wire format at the given offset
* Returns the string buffer and the new offset
*/
function decodeString(data: Buffer, offset: number): { value: Buffer; newOffset: number } {
const len = data.readUInt32BE(offset);
const value = data.subarray(offset + 4, offset + 4 + len);
return { value, newOffset: offset + 4 + len };
}
// ============================================================================
// SSH Public Key Parsing/Encoding
// ============================================================================
/**
* Parse an OpenSSH public key line (e.g., "ssh-ed25519 AAAA... comment")
*/
function parseOpenSSHPublicKey(pubKeyLine: string): {
keyType: string;
keyData: Buffer;
comment: string;
} {
const parts = pubKeyLine.trim().split(/\s+/);
if (parts.length < 2) {
throw new Error("Invalid public key format");
}
const keyType = parts[0];
const keyData = Buffer.from(parts[1], "base64");
const comment = parts.slice(2).join(" ") || "";
// Verify the key type in the blob matches
const { value: blobKeyType } = decodeString(keyData, 0);
if (blobKeyType.toString("utf8") !== keyType) {
throw new Error(`Key type mismatch: ${blobKeyType.toString("utf8")} vs ${keyType}`);
}
return { keyType, keyData, comment };
}
/**
* Encode an Ed25519 public key in OpenSSH format
*/
function encodeEd25519PublicKey(publicKey: Buffer): Buffer {
return Buffer.concat([
encodeString("ssh-ed25519"),
encodeString(publicKey)
]);
}
/**
* Format a public key blob as an OpenSSH public key line
*/
function formatOpenSSHPublicKey(keyBlob: Buffer, comment: string = ""): string {
const { value: keyType } = decodeString(keyBlob, 0);
const base64 = keyBlob.toString("base64");
return `${keyType.toString("utf8")} ${base64}${comment ? " " + comment : ""}`;
}
// ============================================================================
// SSH Certificate Building
// ============================================================================
interface CertificateOptions {
/** Serial number for the certificate */
serial?: bigint;
/** Certificate type: 1 = user, 2 = host */
certType?: number;
/** Key ID (usually username or identifier) */
keyId: string;
/** List of valid principals (usernames the cert is valid for) */
validPrincipals: string[];
/** Valid after timestamp (seconds since epoch) */
validAfter?: bigint;
/** Valid before timestamp (seconds since epoch) */
validBefore?: bigint;
/** Critical options (usually empty for user certs) */
criticalOptions?: Map<string, string>;
/** Extensions to enable */
extensions?: string[];
}
/**
* Build the extensions section of the certificate
*/
function buildExtensions(extensions: string[]): Buffer {
// Extensions are a series of name-value pairs, sorted by name
// For boolean extensions, the value is empty
const sortedExtensions = [...extensions].sort();
const parts: Buffer[] = [];
for (const ext of sortedExtensions) {
parts.push(encodeString(ext));
parts.push(encodeString("")); // Empty value for boolean extensions
}
return encodeString(Buffer.concat(parts));
}
/**
* Build the critical options section
*/
function buildCriticalOptions(options: Map<string, string>): Buffer {
const sortedKeys = [...options.keys()].sort();
const parts: Buffer[] = [];
for (const key of sortedKeys) {
parts.push(encodeString(key));
parts.push(encodeString(encodeString(options.get(key)!)));
}
return encodeString(Buffer.concat(parts));
}
/**
* Build the valid principals section
*/
function buildPrincipals(principals: string[]): Buffer {
const parts: Buffer[] = [];
for (const principal of principals) {
parts.push(encodeString(principal));
}
return encodeString(Buffer.concat(parts));
}
/**
* Extract the raw Ed25519 public key from an OpenSSH public key blob
*/
function extractEd25519PublicKey(keyBlob: Buffer): Buffer {
const { newOffset } = decodeString(keyBlob, 0); // Skip key type
const { value: publicKey } = decodeString(keyBlob, newOffset);
return publicKey;
}
// ============================================================================
// CA Interface
// ============================================================================
export interface CAKeyPair {
/** CA private key in PEM format (keep this secret!) */
privateKeyPem: string;
/** CA public key in PEM format */
publicKeyPem: string;
/** CA public key in OpenSSH format (for TrustedUserCAKeys) */
publicKeyOpenSSH: string;
/** Raw CA public key bytes (Ed25519) */
publicKeyRaw: Buffer;
}
export interface SignedCertificate {
/** The certificate in OpenSSH format (save as id_ed25519-cert.pub or similar) */
certificate: string;
/** The certificate type string */
certType: string;
/** Serial number */
serial: bigint;
/** Key ID */
keyId: string;
/** Valid principals */
validPrincipals: string[];
/** Valid from timestamp */
validAfter: Date;
/** Valid until timestamp */
validBefore: Date;
}
// ============================================================================
// Main Functions
// ============================================================================
/**
* Generate a new SSH Certificate Authority key pair.
*
* Returns the CA keys and the line to add to /etc/ssh/sshd_config:
* TrustedUserCAKeys /etc/ssh/ca.pub
*
* Then save the publicKeyOpenSSH to /etc/ssh/ca.pub on the server.
*
* @param comment - Optional comment for the CA public key
* @returns CA key pair and configuration info
*/
export function generateCA(comment: string = "ssh-ca"): CAKeyPair {
// Generate Ed25519 key pair
const { publicKey, privateKey } = crypto.generateKeyPairSync("ed25519", {
publicKeyEncoding: { type: "spki", format: "pem" },
privateKeyEncoding: { type: "pkcs8", format: "pem" }
});
// Get raw public key bytes
const pubKeyObj = crypto.createPublicKey(publicKey);
const rawPubKey = pubKeyObj.export({ type: "spki", format: "der" });
// Ed25519 SPKI format: 12 byte header + 32 byte key
const ed25519PubKey = rawPubKey.subarray(rawPubKey.length - 32);
// Create OpenSSH format public key
const pubKeyBlob = encodeEd25519PublicKey(ed25519PubKey);
const publicKeyOpenSSH = formatOpenSSHPublicKey(pubKeyBlob, comment);
return {
privateKeyPem: privateKey,
publicKeyPem: publicKey,
publicKeyOpenSSH,
publicKeyRaw: ed25519PubKey
};
}
// ============================================================================
// Helper Functions
// ============================================================================
/**
* Get and decrypt the SSH CA keys for an organization.
*
* @param orgId - Organization ID
* @param decryptionKey - Key to decrypt the CA private key (typically server.secret from config)
* @returns CA key pair or null if not found
*/
export async function getOrgCAKeys(
orgId: string,
decryptionKey: string
): Promise<CAKeyPair | null> {
const { db, orgs } = await import("@server/db");
const { eq } = await import("drizzle-orm");
const { decrypt } = await import("@server/lib/crypto");
const [org] = await db
.select({
sshCaPrivateKey: orgs.sshCaPrivateKey,
sshCaPublicKey: orgs.sshCaPublicKey
})
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org || !org.sshCaPrivateKey || !org.sshCaPublicKey) {
return null;
}
const privateKeyPem = decrypt(org.sshCaPrivateKey, decryptionKey);
// Extract raw public key from the OpenSSH format
const { keyData } = parseOpenSSHPublicKey(org.sshCaPublicKey);
const { newOffset } = decodeString(keyData, 0); // Skip key type
const { value: publicKeyRaw } = decodeString(keyData, newOffset);
// Get PEM format of public key
const pubKeyObj = crypto.createPublicKey({
key: privateKeyPem,
format: "pem"
});
const publicKeyPem = pubKeyObj.export({ type: "spki", format: "pem" }) as string;
return {
privateKeyPem,
publicKeyPem,
publicKeyOpenSSH: org.sshCaPublicKey,
publicKeyRaw
};
}
/**
* Sign a user's SSH public key with the CA, producing a certificate.
*
* The resulting certificate should be saved alongside the user's private key
* with a -cert.pub suffix. For example:
* - Private key: ~/.ssh/id_ed25519
* - Certificate: ~/.ssh/id_ed25519-cert.pub
*
* @param caPrivateKeyPem - CA private key in PEM format
* @param userPublicKeyLine - User's public key in OpenSSH format
* @param options - Certificate options (principals, validity, etc.)
* @returns Signed certificate
*/
export function signPublicKey(
caPrivateKeyPem: string,
userPublicKeyLine: string,
options: CertificateOptions
): SignedCertificate {
// Parse the user's public key
const { keyType, keyData } = parseOpenSSHPublicKey(userPublicKeyLine);
// Determine certificate type string
let certTypeString: string;
if (keyType === "ssh-ed25519") {
certTypeString = "ssh-ed25519-cert-v01@openssh.com";
} else if (keyType === "ssh-rsa") {
certTypeString = "ssh-rsa-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp256") {
certTypeString = "ecdsa-sha2-nistp256-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp384") {
certTypeString = "ecdsa-sha2-nistp384-cert-v01@openssh.com";
} else if (keyType === "ecdsa-sha2-nistp521") {
certTypeString = "ecdsa-sha2-nistp521-cert-v01@openssh.com";
} else {
throw new Error(`Unsupported key type: ${keyType}`);
}
// Get CA public key from private key
const caPrivKey = crypto.createPrivateKey(caPrivateKeyPem);
const caPubKey = crypto.createPublicKey(caPrivKey);
const caRawPubKey = caPubKey.export({ type: "spki", format: "der" });
const caEd25519PubKey = caRawPubKey.subarray(caRawPubKey.length - 32);
const caPubKeyBlob = encodeEd25519PublicKey(caEd25519PubKey);
// Set defaults
const serial = options.serial ?? BigInt(Date.now());
const certType = options.certType ?? 1; // 1 = user cert
const now = BigInt(Math.floor(Date.now() / 1000));
const validAfter = options.validAfter ?? (now - 60n); // 1 minute ago
const validBefore = options.validBefore ?? (now + 86400n * 365n); // 1 year from now
// Default extensions for user certificates
const defaultExtensions = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-port-forwarding",
"permit-pty",
"permit-user-rc"
];
const extensions = options.extensions ?? defaultExtensions;
const criticalOptions = options.criticalOptions ?? new Map();
// Generate nonce (random bytes)
const nonce = crypto.randomBytes(32);
// Extract the public key portion from the user's key blob
// For Ed25519: skip the key type string, get the public key (already encoded)
let userKeyPortion: Buffer;
if (keyType === "ssh-ed25519") {
// Skip the key type string, take the rest (which is encodeString(32-byte-key))
const { newOffset } = decodeString(keyData, 0);
userKeyPortion = keyData.subarray(newOffset);
} else {
// For other key types, extract everything after the key type
const { newOffset } = decodeString(keyData, 0);
userKeyPortion = keyData.subarray(newOffset);
}
// Build the certificate body (to be signed)
const certBody = Buffer.concat([
encodeString(certTypeString),
encodeString(nonce),
userKeyPortion,
encodeUInt64(serial),
encodeUInt32(certType),
encodeString(options.keyId),
buildPrincipals(options.validPrincipals),
encodeUInt64(validAfter),
encodeUInt64(validBefore),
buildCriticalOptions(criticalOptions),
buildExtensions(extensions),
encodeString(""), // reserved
encodeString(caPubKeyBlob) // signature key (CA public key)
]);
// Sign the certificate body
const signature = crypto.sign(null, certBody, caPrivKey);
// Build the full signature blob (algorithm + signature)
const signatureBlob = Buffer.concat([
encodeString("ssh-ed25519"),
encodeString(signature)
]);
// Build complete certificate
const certificate = Buffer.concat([
certBody,
encodeString(signatureBlob)
]);
// Format as OpenSSH certificate line
const certLine = `${certTypeString} ${certificate.toString("base64")} ${options.keyId}`;
return {
certificate: certLine,
certType: certTypeString,
serial,
keyId: options.keyId,
validPrincipals: options.validPrincipals,
validAfter: new Date(Number(validAfter) * 1000),
validBefore: new Date(Number(validBefore) * 1000)
};
}

View File

@@ -19,7 +19,7 @@ import { fromError } from "zod-validation-error";
import type { Request, Response, NextFunction } from "express"; import type { Request, Response, NextFunction } from "express";
import { approvals, db, type Approval } from "@server/db"; import { approvals, db, type Approval } from "@server/db";
import { eq, sql, and } from "drizzle-orm"; import { eq, sql, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
@@ -88,7 +88,7 @@ export async function countApprovals(
.where( .where(
and( and(
eq(approvals.orgId, orgId), eq(approvals.orgId, orgId),
sql`${approvals.decision} in ${state}` inArray(approvals.decision, state)
) )
); );

View File

@@ -28,7 +28,7 @@ import {
currentFingerprint, currentFingerprint,
type Approval type Approval
} from "@server/db"; } from "@server/db";
import { eq, isNull, sql, not, and, desc } from "drizzle-orm"; import { eq, isNull, sql, not, and, desc, gte, lte } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import { getUserDeviceName } from "@server/db/names"; import { getUserDeviceName } from "@server/db/names";
@@ -37,18 +37,26 @@ const paramsSchema = z.strictObject({
}); });
const querySchema = z.strictObject({ const querySchema = z.strictObject({
limit: z limit: z.coerce
.string() .number<string>() // for prettier formatting
.int()
.positive()
.optional() .optional()
.default("1000") .catch(20)
.transform(Number) .default(20),
.pipe(z.int().nonnegative()), cursorPending: z.coerce // pending cursor
offset: z .number<string>()
.string() .int()
.max(1) // 0 means non pending
.min(0) // 1 means pending
.optional() .optional()
.default("0") .catch(undefined),
.transform(Number) cursorTimestamp: z.coerce
.pipe(z.int().nonnegative()), .number<string>()
.int()
.positive()
.optional()
.catch(undefined),
approvalState: z approvalState: z
.enum(["pending", "approved", "denied", "all"]) .enum(["pending", "approved", "denied", "all"])
.optional() .optional()
@@ -61,13 +69,21 @@ const querySchema = z.strictObject({
.pipe(z.number().int().positive().optional()) .pipe(z.number().int().positive().optional())
}); });
async function queryApprovals( async function queryApprovals({
orgId: string, orgId,
limit: number, limit,
offset: number, approvalState,
approvalState: z.infer<typeof querySchema>["approvalState"], cursorPending,
clientId?: number cursorTimestamp,
) { clientId
}: {
orgId: string;
limit: number;
approvalState: z.infer<typeof querySchema>["approvalState"];
cursorPending?: number;
cursorTimestamp?: number;
clientId?: number;
}) {
let state: Array<Approval["decision"]> = []; let state: Array<Approval["decision"]> = [];
switch (approvalState) { switch (approvalState) {
case "pending": case "pending":
@@ -83,6 +99,26 @@ async function queryApprovals(
state = ["approved", "denied", "pending"]; state = ["approved", "denied", "pending"];
} }
const conditions = [
eq(approvals.orgId, orgId),
sql`${approvals.decision} in ${state}`
];
if (clientId) {
conditions.push(eq(approvals.clientId, clientId));
}
const pendingSortKey = sql`CASE ${approvals.decision} WHEN 'pending' THEN 1 ELSE 0 END`;
if (cursorPending != null && cursorTimestamp != null) {
// https://stackoverflow.com/a/79720298/10322846
// composite cursor, next data means (pending, timestamp) <= cursor
conditions.push(
lte(pendingSortKey, cursorPending),
lte(approvals.timestamp, cursorTimestamp)
);
}
const res = await db const res = await db
.select({ .select({
approvalId: approvals.approvalId, approvalId: approvals.approvalId,
@@ -105,7 +141,8 @@ async function queryApprovals(
fingerprintArch: currentFingerprint.arch, fingerprintArch: currentFingerprint.arch,
fingerprintSerialNumber: currentFingerprint.serialNumber, fingerprintSerialNumber: currentFingerprint.serialNumber,
fingerprintUsername: currentFingerprint.username, fingerprintUsername: currentFingerprint.username,
fingerprintHostname: currentFingerprint.hostname fingerprintHostname: currentFingerprint.hostname,
timestamp: approvals.timestamp
}) })
.from(approvals) .from(approvals)
.innerJoin(users, and(eq(approvals.userId, users.userId))) .innerJoin(users, and(eq(approvals.userId, users.userId)))
@@ -118,22 +155,12 @@ async function queryApprovals(
) )
.leftJoin(olms, eq(clients.clientId, olms.clientId)) .leftJoin(olms, eq(clients.clientId, olms.clientId))
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)) .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId))
.where( .where(and(...conditions))
and( .orderBy(desc(pendingSortKey), desc(approvals.timestamp))
eq(approvals.orgId, orgId), .limit(limit + 1); // the `+1` is used for the cursor
sql`${approvals.decision} in ${state}`,
...(clientId ? [eq(approvals.clientId, clientId)] : [])
)
)
.orderBy(
sql`CASE ${approvals.decision} WHEN 'pending' THEN 0 ELSE 1 END`,
desc(approvals.timestamp)
)
.limit(limit)
.offset(offset);
// Process results to format device names and build fingerprint objects // Process results to format device names and build fingerprint objects
return res.map((approval) => { const approvalsList = res.slice(0, limit).map((approval) => {
const model = approval.deviceModel || null; const model = approval.deviceModel || null;
const deviceName = approval.clientName const deviceName = approval.clientName
? getUserDeviceName(model, approval.clientName) ? getUserDeviceName(model, approval.clientName)
@@ -152,14 +179,14 @@ async function queryApprovals(
const fingerprint = hasFingerprintData const fingerprint = hasFingerprintData
? { ? {
platform: approval.fingerprintPlatform || null, platform: approval.fingerprintPlatform ?? null,
osVersion: approval.fingerprintOsVersion || null, osVersion: approval.fingerprintOsVersion ?? null,
kernelVersion: approval.fingerprintKernelVersion || null, kernelVersion: approval.fingerprintKernelVersion ?? null,
arch: approval.fingerprintArch || null, arch: approval.fingerprintArch ?? null,
deviceModel: approval.deviceModel || null, deviceModel: approval.deviceModel ?? null,
serialNumber: approval.fingerprintSerialNumber || null, serialNumber: approval.fingerprintSerialNumber ?? null,
username: approval.fingerprintUsername || null, username: approval.fingerprintUsername ?? null,
hostname: approval.fingerprintHostname || null hostname: approval.fingerprintHostname ?? null
} }
: null; : null;
@@ -183,11 +210,30 @@ async function queryApprovals(
niceId: approval.niceId || null niceId: approval.niceId || null
}; };
}); });
let nextCursorPending: number | null = null;
let nextCursorTimestamp: number | null = null;
if (res.length > limit) {
const lastItem = res[limit];
nextCursorPending = lastItem.decision === "pending" ? 1 : 0;
nextCursorTimestamp = lastItem.timestamp;
}
return {
approvalsList,
nextCursorPending,
nextCursorTimestamp
};
} }
export type ListApprovalsResponse = { export type ListApprovalsResponse = {
approvals: NonNullable<Awaited<ReturnType<typeof queryApprovals>>>; approvals: NonNullable<
pagination: { total: number; limit: number; offset: number }; Awaited<ReturnType<typeof queryApprovals>>
>["approvalsList"];
pagination: {
total: number;
limit: number;
cursorPending: number | null;
cursorTimestamp: number | null;
};
}; };
export async function listApprovals( export async function listApprovals(
@@ -215,17 +261,25 @@ export async function listApprovals(
) )
); );
} }
const { limit, offset, approvalState, clientId } = parsedQuery.data; const {
limit,
cursorPending,
cursorTimestamp,
approvalState,
clientId
} = parsedQuery.data;
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const approvalsList = await queryApprovals( const { approvalsList, nextCursorPending, nextCursorTimestamp } =
orgId.toString(), await queryApprovals({
orgId: orgId.toString(),
limit, limit,
offset, cursorPending,
cursorTimestamp,
approvalState, approvalState,
clientId clientId
); });
const [{ count }] = await db const [{ count }] = await db
.select({ count: sql<number>`count(*)` }) .select({ count: sql<number>`count(*)` })
@@ -237,7 +291,8 @@ export async function listApprovals(
pagination: { pagination: {
total: count, total: count,
limit, limit,
offset cursorPending: nextCursorPending,
cursorTimestamp: nextCursorTimestamp
} }
}, },
success: true, success: true,

View File

@@ -15,7 +15,18 @@ import { SubscriptionType } from "./hooks/getSubType";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { Tier } from "@server/types/Tiers"; import { Tier } from "@server/types/Tiers";
import logger from "@server/logger"; import logger from "@server/logger";
import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db"; import {
db,
idp,
idpOrg,
loginPage,
loginPageBranding,
loginPageBrandingOrg,
loginPageOrg,
orgs,
resources,
roles
} from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
/** /**
@@ -59,10 +70,7 @@ async function capRetentionDays(
} }
// Get current org settings // Get current org settings
const [org] = await db const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
if (!org) { if (!org) {
logger.warn(`Org ${orgId} not found when capping retention days`); logger.warn(`Org ${orgId} not found when capping retention days`);
@@ -110,18 +118,13 @@ async function capRetentionDays(
// Apply updates if needed // Apply updates if needed
if (needsUpdate) { if (needsUpdate) {
await db await db.update(orgs).set(updates).where(eq(orgs.orgId, orgId));
.update(orgs)
.set(updates)
.where(eq(orgs.orgId, orgId));
logger.info( logger.info(
`Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days` `Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days`
); );
} else { } else {
logger.debug( logger.debug(`No retention day capping needed for org ${orgId}`);
`No retention day capping needed for org ${orgId}`
);
} }
} }
@@ -134,6 +137,35 @@ export async function handleTierChange(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}` `Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
); );
// Get all orgs that have this orgId as their billingOrgId
const associatedOrgs = await db
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, orgId));
logger.info(
`Found ${associatedOrgs.length} org(s) associated with billing org ${orgId}`
);
// Loop over all associated orgs and apply tier changes
for (const org of associatedOrgs) {
await handleTierChangeForOrg(org.orgId, newTier, previousTier);
}
logger.info(
`Completed tier change handling for all orgs associated with billing org ${orgId}`
);
}
async function handleTierChangeForOrg(
orgId: string,
newTier: SubscriptionType | null,
previousTier?: SubscriptionType | null
): Promise<void> {
logger.info(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// License subscriptions are handled separately and don't use the tier matrix // License subscriptions are handled separately and don't use the tier matrix
if (newTier === "license") { if (newTier === "license") {
logger.debug( logger.debug(
@@ -314,9 +346,7 @@ async function disableLoginPageDomain(orgId: string): Promise<void> {
); );
if (existingLoginPage) { if (existingLoginPage) {
await db await db.delete(loginPageOrg).where(eq(loginPageOrg.orgId, orgId));
.delete(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId));
await db await db
.delete(loginPage) .delete(loginPage)

View File

@@ -112,11 +112,13 @@ export async function getOrgSubscriptionsData(
throw new Error(`Not found`); throw new Error(`Not found`);
} }
const billingOrgId = org[0].billingOrgId || org[0].orgId;
// Get customer for org // Get customer for org
const customer = await db const customer = await db
.select() .select()
.from(customers) .from(customers)
.where(eq(customers.orgId, orgId)) .where(eq(customers.orgId, billingOrgId))
.limit(1); .limit(1);
const subscriptionsWithItems: Array<{ const subscriptionsWithItems: Array<{

View File

@@ -85,10 +85,14 @@ export async function getOrgUsage(
orgId, orgId,
FeatureId.REMOTE_EXIT_NODES FeatureId.REMOTE_EXIT_NODES
); );
const egressData = await usageService.getUsage( const organizations = await usageService.getUsage(
orgId, orgId,
FeatureId.EGRESS_DATA_MB FeatureId.ORGINIZATIONS
); );
// const egressData = await usageService.getUsage(
// orgId,
// FeatureId.EGRESS_DATA_MB
// );
if (sites) { if (sites) {
usageData.push(sites); usageData.push(sites);
@@ -96,15 +100,18 @@ export async function getOrgUsage(
if (users) { if (users) {
usageData.push(users); usageData.push(users);
} }
if (egressData) { // if (egressData) {
usageData.push(egressData); // usageData.push(egressData);
} // }
if (domains) { if (domains) {
usageData.push(domains); usageData.push(domains);
} }
if (remoteExitNodes) { if (remoteExitNodes) {
usageData.push(remoteExitNodes); usageData.push(remoteExitNodes);
} }
if (organizations) {
usageData.push(organizations);
}
const orgLimits = await db const orgLimits = await db
.select() .select()

View File

@@ -25,6 +25,7 @@ import * as logs from "#private/routers/auditLogs";
import * as misc from "#private/routers/misc"; import * as misc from "#private/routers/misc";
import * as reKey from "#private/routers/re-key"; import * as reKey from "#private/routers/re-key";
import * as approval from "#private/routers/approvals"; import * as approval from "#private/routers/approvals";
import * as ssh from "#private/routers/ssh";
import { import {
verifyOrgAccess, verifyOrgAccess,
@@ -506,3 +507,14 @@ authenticated.put(
verifyUserHasAction(ActionsEnum.reGenerateSecret), verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateExitNodeSecret reKey.reGenerateExitNodeSecret
); );
authenticated.post(
"/org/:orgId/ssh/sign-key",
verifyValidLicense,
verifyValidSubscription(tierMatrix.sshPam),
verifyOrgAccess,
verifyLimits,
// verifyUserHasAction(ActionsEnum.signSshKey),
logActionAudit(ActionsEnum.signSshKey),
ssh.signSshKey
);

View File

@@ -37,8 +37,9 @@ export async function generateNewEnterpriseLicense(
next: NextFunction next: NextFunction
): Promise<any> { ): Promise<any> {
try { try {
const parsedParams = generateNewEnterpriseLicenseParamsSchema.safeParse(
const parsedParams = generateNewEnterpriseLicenseParamsSchema.safeParse(req.params); req.params
);
if (!parsedParams.success) { if (!parsedParams.success) {
return next( return next(
createHttpError( createHttpError(
@@ -63,7 +64,10 @@ export async function generateNewEnterpriseLicense(
const licenseData = req.body; const licenseData = req.body;
if (licenseData.tier != "big_license" && licenseData.tier != "small_license") { if (
licenseData.tier != "big_license" &&
licenseData.tier != "small_license"
) {
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
@@ -79,7 +83,8 @@ export async function generateNewEnterpriseLicense(
return next( return next(
createHttpError( createHttpError(
apiResponse.status || HttpCode.BAD_REQUEST, apiResponse.status || HttpCode.BAD_REQUEST,
apiResponse.message || "Failed to create license from Fossorial API" apiResponse.message ||
"Failed to create license from Fossorial API"
) )
); );
} }
@@ -112,7 +117,10 @@ export async function generateNewEnterpriseLicense(
); );
} }
const tier = licenseData.tier === "big_license" ? LicenseId.BIG_LICENSE : LicenseId.SMALL_LICENSE; const tier =
licenseData.tier === "big_license"
? LicenseId.BIG_LICENSE
: LicenseId.SMALL_LICENSE;
const tierPrice = getLicensePriceSet()[tier]; const tierPrice = getLicensePriceSet()[tier];
const session = await stripe!.checkout.sessions.create({ const session = await stripe!.checkout.sessions.create({
@@ -122,7 +130,7 @@ export async function generateNewEnterpriseLicense(
{ {
price: tierPrice, // Use the standard tier price: tierPrice, // Use the standard tier
quantity: 1 quantity: 1
}, }
], // Start with the standard feature set that matches the free limits ], // Start with the standard feature set that matches the free limits
customer: customer.customerId, customer: customer.customerId,
mode: "subscription", mode: "subscription",

View File

@@ -26,6 +26,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq, InferInsertModel } from "drizzle-orm"; import { eq, InferInsertModel } from "drizzle-orm";
import { build } from "@server/build"; import { build } from "@server/build";
import { validateLocalPath } from "@app/lib/validateLocalPath";
import config from "#private/lib/config"; import config from "#private/lib/config";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
@@ -37,14 +38,36 @@ const bodySchema = z.strictObject({
.union([ .union([
z.literal(""), z.literal(""),
z z
.url("Must be a valid URL") .string()
.superRefine(async (url, ctx) => { .superRefine(async (urlOrPath, ctx) => {
const parseResult = z.url().safeParse(urlOrPath);
if (!parseResult.success) {
if (build !== "enterprise") {
ctx.addIssue({
code: "custom",
message: "Must be a valid URL"
});
return;
} else {
try { try {
const response = await fetch(url, { validateLocalPath(urlOrPath);
} catch (error) {
ctx.addIssue({
code: "custom",
message: "Must be either a valid image URL or a valid pathname starting with `/` and not containing query parameters, `..` or `*`"
});
} finally {
return;
}
}
}
try {
const response = await fetch(urlOrPath, {
method: "HEAD" method: "HEAD"
}).catch(() => { }).catch(() => {
// If HEAD fails (CORS or method not allowed), try GET // If HEAD fails (CORS or method not allowed), try GET
return fetch(url, { method: "GET" }); return fetch(urlOrPath, { method: "GET" });
}); });
if (response.status !== 200) { if (response.status !== 200) {

View File

@@ -12,7 +12,14 @@
*/ */
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db"; import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
orgs
} from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { z } from "zod"; import { z } from "zod";
import { remoteExitNodes } from "@server/db"; import { remoteExitNodes } from "@server/db";
@@ -25,7 +32,7 @@ import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNo
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { hashPassword, verifyPassword } from "@server/auth/password"; import { hashPassword, verifyPassword } from "@server/auth/password";
import logger from "@server/logger"; import logger from "@server/logger";
import { and, eq } from "drizzle-orm"; import { and, eq, inArray, ne } from "drizzle-orm";
import { getNextAvailableSubnet } from "@server/lib/exitNodes"; import { getNextAvailableSubnet } from "@server/lib/exitNodes";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing"; import { FeatureId } from "@server/lib/billing";
@@ -169,7 +176,17 @@ export async function createRemoteExitNode(
); );
} }
let numExitNodeOrgs: ExitNodeOrg[] | undefined; const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
if (!existingExitNode) { if (!existingExitNode) {
@@ -217,19 +234,43 @@ export async function createRemoteExitNode(
}); });
} }
numExitNodeOrgs = await trx // calculate if the node is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, orgId)
)
);
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select() .select()
.from(exitNodeOrgs) .from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId)); .where(
}); and(
eq(
exitNodeOrgs.exitNodeId,
existingExitNode.exitNodeId
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
if (numExitNodeOrgs) { if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.updateCount( await usageService.add(
orgId, orgId,
FeatureId.REMOTE_EXIT_NODES, FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length 1,
trx
); );
} }
}
});
const token = generateSessionToken(); const token = generateSessionToken();
await createRemoteExitNodeSession(token, remoteExitNodeId); await createRemoteExitNodeSession(token, remoteExitNodeId);

View File

@@ -13,9 +13,9 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db"; import { db, ExitNodeOrg, exitNodeOrgs, exitNodes, orgs } from "@server/db";
import { remoteExitNodes } from "@server/db"; import { remoteExitNodes } from "@server/db";
import { and, count, eq } from "drizzle-orm"; import { and, count, eq, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -50,7 +50,8 @@ export async function deleteRemoteExitNode(
const [remoteExitNode] = await db const [remoteExitNode] = await db
.select() .select()
.from(remoteExitNodes) .from(remoteExitNodes)
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)); .where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId))
.limit(1);
if (!remoteExitNode) { if (!remoteExitNode) {
return next( return next(
@@ -70,7 +71,17 @@ export async function deleteRemoteExitNode(
); );
} }
let numExitNodeOrgs: ExitNodeOrg[] | undefined; const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Org with ID ${orgId} not found`
)
);
}
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(exitNodeOrgs) .delete(exitNodeOrgs)
@@ -81,37 +92,38 @@ export async function deleteRemoteExitNode(
) )
); );
const [remainingExitNodeOrgs] = await trx // calculate if the user is in any other of the orgs before we count it as an remove to the billing org
.select({ count: count() }) if (org.billingOrgId) {
.from(exitNodeOrgs) const otherBillingOrgs = await trx
.where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!)); .select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (remainingExitNodeOrgs.count === 0) { const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
await trx
.delete(remoteExitNodes)
.where(
eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)
);
await trx
.delete(exitNodes)
.where(
eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!)
);
}
numExitNodeOrgs = await trx const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select() .select()
.from(exitNodeOrgs) .from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId)); .where(
}); and(
eq(
exitNodeOrgs.exitNodeId,
remoteExitNode.exitNodeId!
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
if (numExitNodeOrgs) { if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.updateCount( await usageService.add(
orgId, orgId,
FeatureId.REMOTE_EXIT_NODES, FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length -1,
trx
); );
} }
}
});
return response(res, { return response(res, {
data: null, data: null,

View File

@@ -0,0 +1,14 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
export * from "./signSshKey";

View File

@@ -0,0 +1,403 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts, orgs, roundTripMessageTracker, siteResources, sites, userOrgs } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { eq, or, and } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA";
import config from "@server/lib/config";
import { sendToClient } from "#private/routers/ws";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
});
const bodySchema = z
.strictObject({
publicKey: z.string().nonempty(),
resourceId: z.number().int().positive().optional(),
resource: z.string().nonempty().optional() // this is either the nice id or the alias
})
.refine(
(data) => {
const fields = [data.resourceId, data.resource];
const definedFields = fields.filter((field) => field !== undefined);
return definedFields.length === 1;
},
{
message:
"Exactly one of resourceId, niceId, or alias must be provided"
}
);
export type SignSshKeyResponse = {
certificate: string;
messageId: number;
sshUsername: string;
sshHost: string;
resourceId: number;
keyId: string;
validPrincipals: string[];
validAfter: string;
validBefore: string;
expiresIn: number;
};
// registry.registerPath({
// method: "post",
// path: "/org/{orgId}/ssh/sign-key",
// description: "Sign an SSH public key for access to a resource.",
// tags: [OpenAPITags.Org, OpenAPITags.Ssh],
// request: {
// params: paramsSchema,
// body: {
// content: {
// "application/json": {
// schema: bodySchema
// }
// }
// }
// },
// responses: {}
// });
export async function signSshKey(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = paramsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const {
publicKey,
resourceId,
resource: resourceQueryString
} = parsedBody.data;
const userId = req.user?.userId;
const roleId = req.userOrgRoleId!;
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated")
);
}
const [userOrg] = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not belong to the specified organization"
)
);
}
let usernameToUse;
if (!userOrg.pamUsername) {
if (req.user?.email) {
// Extract username from email (first part before @)
usernameToUse = req.user?.email.split("@")[0];
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Unable to extract username from email"
)
);
}
} else if (req.user?.username) {
usernameToUse = req.user.username;
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Username is not valid for SSH certificate"
)
);
}
} else {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User does not have a valid email or username for SSH certificate"
)
);
}
// check if we have a existing user in this org with the same
const [existingUserWithSameName] = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.pamUsername, usernameToUse)
)
)
.limit(1);
if (existingUserWithSameName) {
let foundUniqueUsername = false;
for (let attempt = 0; attempt < 20; attempt++) {
const randomNum = Math.floor(Math.random() * 101); // 0 to 100
const candidateUsername = `${usernameToUse}${randomNum}`;
const [existingUser] = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.pamUsername, candidateUsername)
)
)
.limit(1);
if (!existingUser) {
usernameToUse = candidateUsername;
foundUniqueUsername = true;
break;
}
}
if (!foundUniqueUsername) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Unable to generate a unique username for SSH certificate"
)
);
}
}
} else {
usernameToUse = userOrg.pamUsername;
}
// Get and decrypt the org's CA keys
const caKeys = await getOrgCAKeys(
orgId,
config.getRawConfig().server.secret!
);
if (!caKeys) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"SSH CA not configured for this organization"
)
);
}
// Verify the resource exists and belongs to the org
// Build the where clause dynamically based on which field is provided
let whereClause;
if (resourceId !== undefined) {
whereClause = eq(siteResources.siteResourceId, resourceId);
} else if (resourceQueryString !== undefined) {
whereClause = or(
eq(siteResources.niceId, resourceQueryString),
eq(siteResources.alias, resourceQueryString)
);
} else {
// This should never happen due to the schema validation, but TypeScript doesn't know that
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"One of resourceId, niceId, or alias must be provided"
)
);
}
const resources = await db
.select()
.from(siteResources)
.where(and(whereClause, eq(siteResources.orgId, orgId)));
if (!resources || resources.length === 0) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Resource not found`)
);
}
if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Multiple resources found matching the criteria`
)
);
}
const resource = resources[0];
if (resource.orgId !== orgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Resource does not belong to the specified organization"
)
);
}
// Check if the user has access to the resource
const hasAccess = await canUserAccessSiteResource({
userId: userId,
resourceId: resource.siteResourceId,
roleId: roleId
});
if (!hasAccess) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this resource"
)
);
}
// get the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, resource.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found"
)
);
}
// Sign the public key
const now = BigInt(Math.floor(Date.now() / 1000));
// only valid for 5 minutes
const validFor = 300n;
const cert = signPublicKey(caKeys.privateKeyPem, publicKey, {
keyId: `${usernameToUse}@${resource.niceId}`,
validPrincipals: [usernameToUse, resource.niceId],
validAfter: now - 60n, // Start 1 min ago for clock skew
validBefore: now + validFor
});
const [message] = await db
.insert(roundTripMessageTracker)
.values({
wsClientId: newt.newtId,
messageType: `newt/pam/connection`,
sentAt: Math.floor(Date.now() / 1000),
})
.returning();
if (!message) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create message tracker entry"
)
);
}
await sendToClient(newt.newtId, {
type: `newt/pam/connection`,
data: {
messageId: message.messageId,
orgId: orgId,
agentPort: 22123,
agentHost: resource.destination,
caCert: caKeys.publicKeyOpenSSH,
username: usernameToUse,
niceId: resource.niceId,
metadata: {
sudo: true, // we are hardcoding these for now but should make configurable from the role or something
homedir: true
}
}
});
const expiresIn = Number(validFor); // seconds
let sshHost;
if (resource.alias && resource.alias != "") {
sshHost = resource.alias;
} else {
sshHost = resource.destination;
}
return response<SignSshKeyResponse>(res, {
data: {
certificate: cert.certificate,
messageId: message.messageId,
sshUsername: usernameToUse,
sshHost: sshHost,
resourceId: resource.siteResourceId,
keyId: cert.keyId,
validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(),
validBefore: cert.validBefore.toISOString(),
expiresIn
},
success: true,
error: false,
message: "SSH key signed successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error signing SSH key:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred while signing the SSH key"
)
);
}
}

View File

@@ -15,11 +15,10 @@ import {
import { verifyPassword } from "@server/auth/password"; import { verifyPassword } from "@server/auth/password";
import { verifyTotpCode } from "@server/auth/totp"; import { verifyTotpCode } from "@server/auth/totp";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg";
deleteOrgById,
sendTerminationMessages
} from "@server/lib/deleteOrg";
import { UserType } from "@server/types/UserTypes"; import { UserType } from "@server/types/UserTypes";
import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing";
const deleteMyAccountBody = z.strictObject({ const deleteMyAccountBody = z.strictObject({
password: z.string().optional(), password: z.string().optional(),
@@ -40,11 +39,6 @@ export type DeleteMyAccountSuccessResponse = {
success: true; success: true;
}; };
/**
* Self-service account deletion (saas only). Returns preview when no password;
* requires password and optional 2FA code to perform deletion. Uses shared
* deleteOrgById for each owned org (delete-my-account may delete multiple orgs).
*/
export async function deleteMyAccount( export async function deleteMyAccount(
req: Request, req: Request,
res: Response, res: Response,
@@ -91,18 +85,35 @@ export async function deleteMyAccount(
const ownedOrgsRows = await db const ownedOrgsRows = await db
.select({ .select({
orgId: userOrgs.orgId orgId: userOrgs.orgId,
isOwner: userOrgs.isOwner,
isBillingOrg: orgs.isBillingOrg
}) })
.from(userOrgs) .from(userOrgs)
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
.where( .where(
and( and(eq(userOrgs.userId, userId), eq(userOrgs.isOwner, true))
eq(userOrgs.userId, userId),
eq(userOrgs.isOwner, true)
)
); );
const orgIds = ownedOrgsRows.map((r) => r.orgId); const orgIds = ownedOrgsRows.map((r) => r.orgId);
if (build === "saas" && orgIds.length > 0) {
const primaryOrgId = ownedOrgsRows.find(
(r) => r.isBillingOrg && r.isOwner
)?.orgId;
if (primaryOrgId) {
const { tier, active } = await getOrgTierData(primaryOrgId);
if (active && tier) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"You must cancel your subscription before deleting your account"
)
);
}
}
}
if (!password) { if (!password) {
const orgsWithNames = const orgsWithNames =
orgIds.length > 0 orgIds.length > 0
@@ -219,10 +230,7 @@ export async function deleteMyAccount(
} catch (error) { } catch (error) {
logger.error(error); logger.error(error);
return next( return next(
createHttpError( createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred"
)
); );
} }
} }

View File

@@ -1,7 +1,7 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { db, users } from "@server/db"; import { db, users } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { z } from "zod"; import { email, z } from "zod";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -21,7 +21,6 @@ import { hashPassword } from "@server/auth/password";
import { checkValidInvite } from "@server/auth/checkValidInvite"; import { checkValidInvite } from "@server/auth/checkValidInvite";
import { passwordSchema } from "@server/auth/passwordSchema"; import { passwordSchema } from "@server/auth/passwordSchema";
import { UserType } from "@server/types/UserTypes"; import { UserType } from "@server/types/UserTypes";
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
import { build } from "@server/build"; import { build } from "@server/build";
import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend"; import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend";
@@ -31,7 +30,8 @@ export const signupBodySchema = z.object({
inviteToken: z.string().optional(), inviteToken: z.string().optional(),
inviteId: z.string().optional(), inviteId: z.string().optional(),
termsAcceptedTimestamp: z.string().nullable().optional(), termsAcceptedTimestamp: z.string().nullable().optional(),
marketingEmailConsent: z.boolean().optional() marketingEmailConsent: z.boolean().optional(),
skipVerificationEmail: z.boolean().optional()
}); });
export type SignUpBody = z.infer<typeof signupBodySchema>; export type SignUpBody = z.infer<typeof signupBodySchema>;
@@ -62,7 +62,8 @@ export async function signup(
inviteToken, inviteToken,
inviteId, inviteId,
termsAcceptedTimestamp, termsAcceptedTimestamp,
marketingEmailConsent marketingEmailConsent,
skipVerificationEmail
} = parsedBody.data; } = parsedBody.data;
const passwordHash = await hashPassword(password); const passwordHash = await hashPassword(password);
@@ -198,26 +199,6 @@ export async function signup(
// orgId: null, // orgId: null,
// }); // });
if (build == "saas") {
const { success, error, org } = await createUserAccountOrg(
userId,
email
);
if (!success) {
if (error) {
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)
);
}
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create user account and organization"
)
);
}
}
const token = generateSessionToken(); const token = generateSessionToken();
const sess = await createSession(token, userId); const sess = await createSession(token, userId);
const isSecure = req.protocol === "https"; const isSecure = req.protocol === "https";
@@ -235,7 +216,13 @@ export async function signup(
} }
if (config.getRawConfig().flags?.require_email_verification) { if (config.getRawConfig().flags?.require_email_verification) {
if (!skipVerificationEmail) {
sendEmailVerificationCode(email, userId); sendEmailVerificationCode(email, userId);
} else {
logger.debug(
`User ${email} opted out of verification email during signup.`
);
}
return response<SignUpResponse>(res, { return response<SignUpResponse>(res, {
data: { data: {
@@ -243,7 +230,9 @@ export async function signup(
}, },
success: true, success: true,
error: false, error: false,
message: `User created successfully. We sent an email to ${email} with a verification code.`, message: skipVerificationEmail
? "User created successfully. Please verify your email."
: `User created successfully. We sent an email to ${email} with a verification code.`,
status: HttpCode.OK status: HttpCode.OK
}); });
} }

View File

@@ -6,6 +6,7 @@ export * from "./unarchiveClient";
export * from "./blockClient"; export * from "./blockClient";
export * from "./unblockClient"; export * from "./unblockClient";
export * from "./listClients"; export * from "./listClients";
export * from "./listUserDevices";
export * from "./updateClient"; export * from "./updateClient";
export * from "./getClient"; export * from "./getClient";
export * from "./createUserClient"; export * from "./createUserClient";

View File

@@ -1,34 +1,38 @@
import { db, olms, users } from "@server/db";
import { import {
clients, clients,
clientSitesAssociationsCache,
currentFingerprint,
db,
olms,
orgs, orgs,
roleClients, roleClients,
sites, sites,
userClients, userClients,
clientSitesAssociationsCache, users
currentFingerprint
} from "@server/db"; } from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import type { PaginatedResponse } from "@server/types/Pagination";
import { import {
and, and,
count, asc,
desc,
eq, eq,
inArray, inArray,
isNotNull,
isNull, isNull,
like,
or, or,
sql sql,
type SQL
} from "drizzle-orm"; } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import NodeCache from "node-cache"; import NodeCache from "node-cache";
import semver from "semver"; import semver from "semver";
import { getUserDeviceName } from "@server/db/names"; import { z } from "zod";
import { fromError } from "zod-validation-error";
const olmVersionCache = new NodeCache({ stdTTL: 3600 }); const olmVersionCache = new NodeCache({ stdTTL: 3600 });
@@ -89,38 +93,86 @@ const listClientsParamsSchema = z.strictObject({
}); });
const listClientsSchema = z.object({ const listClientsSchema = z.object({
limit: z pageSize: z.coerce
.string() .number<string>() // for prettier formatting
.int()
.positive()
.optional() .optional()
.default("1000") .catch(20)
.transform(Number) .default(20)
.pipe(z.int().positive()), .openapi({
offset: z type: "integer",
.string() default: 20,
description: "Number of items per page"
}),
page: z.coerce
.number<string>() // for prettier formatting
.int()
.min(0)
.optional() .optional()
.default("0") .catch(1)
.transform(Number) .default(1)
.pipe(z.int().nonnegative()), .openapi({
filter: z.enum(["user", "machine"]).optional() type: "integer",
default: 1,
description: "Page number to retrieve"
}),
query: z.string().optional(),
sort_by: z
.enum(["megabytesIn", "megabytesOut"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["megabytesIn", "megabytesOut"],
description: "Field to sort by"
}),
order: z
.enum(["asc", "desc"])
.optional()
.default("asc")
.catch("asc")
.openapi({
type: "string",
enum: ["asc", "desc"],
default: "asc",
description: "Sort order"
}),
online: z
.enum(["true", "false"])
.transform((v) => v === "true")
.optional()
.catch(undefined)
.openapi({
type: "boolean",
description: "Filter by online status"
}),
status: z.preprocess(
(val: string | undefined) => {
if (val) {
return val.split(","); // the search query array is an array joined by commas
}
return undefined;
},
z
.array(z.enum(["active", "blocked", "archived"]))
.optional()
.default(["active"])
.catch(["active"])
.openapi({
type: "array",
items: {
type: "string",
enum: ["active", "blocked", "archived"]
},
default: ["active"],
description:
"Filter by client status. Can be a comma-separated list of values. Defaults to 'active'."
})
)
}); });
function queryClients( function queryClientsBase() {
orgId: string,
accessibleClientIds: number[],
filter?: "user" | "machine"
) {
const conditions = [
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId)
];
// Add filter condition based on filter type
if (filter === "user") {
conditions.push(isNotNull(clients.userId));
} else if (filter === "machine") {
conditions.push(isNull(clients.userId));
}
return db return db
.select({ .select({
clientId: clients.clientId, clientId: clients.clientId,
@@ -142,22 +194,13 @@ function queryClients(
approvalState: clients.approvalState, approvalState: clients.approvalState,
olmArchived: olms.archived, olmArchived: olms.archived,
archived: clients.archived, archived: clients.archived,
blocked: clients.blocked, blocked: clients.blocked
deviceModel: currentFingerprint.deviceModel,
fingerprintPlatform: currentFingerprint.platform,
fingerprintOsVersion: currentFingerprint.osVersion,
fingerprintKernelVersion: currentFingerprint.kernelVersion,
fingerprintArch: currentFingerprint.arch,
fingerprintSerialNumber: currentFingerprint.serialNumber,
fingerprintUsername: currentFingerprint.username,
fingerprintHostname: currentFingerprint.hostname
}) })
.from(clients) .from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId)) .leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.leftJoin(olms, eq(clients.clientId, olms.clientId)) .leftJoin(olms, eq(clients.clientId, olms.clientId))
.leftJoin(users, eq(clients.userId, users.userId)) .leftJoin(users, eq(clients.userId, users.userId))
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)) .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId));
.where(and(...conditions));
} }
async function getSiteAssociations(clientIds: number[]) { async function getSiteAssociations(clientIds: number[]) {
@@ -175,7 +218,7 @@ async function getSiteAssociations(clientIds: number[]) {
.where(inArray(clientSitesAssociationsCache.clientId, clientIds)); .where(inArray(clientSitesAssociationsCache.clientId, clientIds));
} }
type ClientWithSites = Awaited<ReturnType<typeof queryClients>>[0] & { type ClientWithSites = Awaited<ReturnType<typeof queryClientsBase>>[0] & {
sites: Array<{ sites: Array<{
siteId: number; siteId: number;
siteName: string | null; siteName: string | null;
@@ -186,10 +229,9 @@ type ClientWithSites = Awaited<ReturnType<typeof queryClients>>[0] & {
type OlmWithUpdateAvailable = ClientWithSites; type OlmWithUpdateAvailable = ClientWithSites;
export type ListClientsResponse = { export type ListClientsResponse = PaginatedResponse<{
clients: Array<ClientWithSites>; clients: Array<ClientWithSites>;
pagination: { total: number; limit: number; offset: number }; }>;
};
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
@@ -218,7 +260,8 @@ export async function listClients(
) )
); );
} }
const { limit, offset, filter } = parsedQuery.data; const { page, pageSize, online, query, status, sort_by, order } =
parsedQuery.data;
const parsedParams = listClientsParamsSchema.safeParse(req.params); const parsedParams = listClientsParamsSchema.safeParse(req.params);
if (!parsedParams.success) { if (!parsedParams.success) {
@@ -267,28 +310,73 @@ export async function listClients(
const accessibleClientIds = accessibleClients.map( const accessibleClientIds = accessibleClients.map(
(client) => client.clientId (client) => client.clientId
); );
const baseQuery = queryClients(orgId, accessibleClientIds, filter);
// Get client count with filter // Get client count with filter
const countConditions = [ const conditions = [
and(
inArray(clients.clientId, accessibleClientIds), inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId) eq(clients.orgId, orgId),
isNull(clients.userId)
)
]; ];
if (filter === "user") { if (typeof online !== "undefined") {
countConditions.push(isNotNull(clients.userId)); conditions.push(eq(clients.online, online));
} else if (filter === "machine") {
countConditions.push(isNull(clients.userId));
} }
const countQuery = db if (status.length > 0) {
.select({ count: count() }) const filterAggregates: (SQL<unknown> | undefined)[] = [];
.from(clients)
.where(and(...countConditions));
const clientsList = await baseQuery.limit(limit).offset(offset); if (status.includes("active")) {
const totalCountResult = await countQuery; filterAggregates.push(
const totalCount = totalCountResult[0].count; and(eq(clients.archived, false), eq(clients.blocked, false))
);
}
if (status.includes("archived")) {
filterAggregates.push(eq(clients.archived, true));
}
if (status.includes("blocked")) {
filterAggregates.push(eq(clients.blocked, true));
}
conditions.push(or(...filterAggregates));
}
if (query) {
conditions.push(
or(
like(
sql`LOWER(${clients.name})`,
"%" + query.toLowerCase() + "%"
),
like(
sql`LOWER(${clients.niceId})`,
"%" + query.toLowerCase() + "%"
)
)
);
}
const baseQuery = queryClientsBase().where(and(...conditions));
const countQuery = db.$count(baseQuery.as("filtered_clients"));
const listMachinesQuery = baseQuery
.limit(page)
.offset(pageSize * (page - 1))
.orderBy(
sort_by
? order === "asc"
? asc(clients[sort_by])
: desc(clients[sort_by])
: asc(clients.clientId)
);
const [clientsList, totalCount] = await Promise.all([
listMachinesQuery,
countQuery
]);
// Get associated sites for all clients // Get associated sites for all clients
const clientIds = clientsList.map((client) => client.clientId); const clientIds = clientsList.map((client) => client.clientId);
@@ -319,14 +407,8 @@ export async function listClients(
// Merge clients with their site associations and replace name with device name // Merge clients with their site associations and replace name with device name
const clientsWithSites = clientsList.map((client) => { const clientsWithSites = clientsList.map((client) => {
const model = client.deviceModel || null;
let newName = client.name;
if (filter === "user") {
newName = getUserDeviceName(model, client.name);
}
return { return {
...client, ...client,
name: newName,
sites: sitesByClient[client.clientId] || [] sites: sitesByClient[client.clientId] || []
}; };
}); });
@@ -371,8 +453,8 @@ export async function listClients(
clients: olmsWithUpdates, clients: olmsWithUpdates,
pagination: { pagination: {
total: totalCount, total: totalCount,
limit, page,
offset pageSize
} }
}, },
success: true, success: true,

View File

@@ -0,0 +1,500 @@
import { build } from "@server/build";
import {
clients,
currentFingerprint,
db,
olms,
orgs,
roleClients,
userClients,
users
} from "@server/db";
import { getUserDeviceName } from "@server/db/names";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import type { PaginatedResponse } from "@server/types/Pagination";
import {
and,
asc,
desc,
eq,
inArray,
isNotNull,
isNull,
like,
or,
sql,
type SQL
} from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import NodeCache from "node-cache";
import semver from "semver";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
async function getLatestOlmVersion(): Promise<string | null> {
try {
const cachedVersion = olmVersionCache.get<string>("latestOlmVersion");
if (cachedVersion) {
return cachedVersion;
}
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 1500);
const response = await fetch(
"https://api.github.com/repos/fosrl/olm/tags",
{
signal: controller.signal
}
);
clearTimeout(timeoutId);
if (!response.ok) {
logger.warn(
`Failed to fetch latest Olm version from GitHub: ${response.status} ${response.statusText}`
);
return null;
}
let tags = await response.json();
if (!Array.isArray(tags) || tags.length === 0) {
logger.warn("No tags found for Olm repository");
return null;
}
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
olmVersionCache.set("latestOlmVersion", latestVersion);
return latestVersion;
} catch (error: any) {
if (error.name === "AbortError") {
logger.warn("Request to fetch latest Olm version timed out (1.5s)");
} else if (error.cause?.code === "UND_ERR_CONNECT_TIMEOUT") {
logger.warn("Connection timeout while fetching latest Olm version");
} else {
logger.warn(
"Error fetching latest Olm version:",
error.message || error
);
}
return null;
}
}
const listUserDevicesParamsSchema = z.strictObject({
orgId: z.string()
});
const listUserDevicesSchema = z.object({
pageSize: z.coerce
.number<string>() // for prettier formatting
.int()
.positive()
.optional()
.catch(20)
.default(20)
.openapi({
type: "integer",
default: 20,
description: "Number of items per page"
}),
page: z.coerce
.number<string>() // for prettier formatting
.int()
.min(0)
.optional()
.catch(1)
.default(1)
.openapi({
type: "integer",
default: 1,
description: "Page number to retrieve"
}),
query: z.string().optional(),
sort_by: z
.enum(["megabytesIn", "megabytesOut"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["megabytesIn", "megabytesOut"],
description: "Field to sort by"
}),
order: z
.enum(["asc", "desc"])
.optional()
.default("asc")
.catch("asc")
.openapi({
type: "string",
enum: ["asc", "desc"],
default: "asc",
description: "Sort order"
}),
online: z
.enum(["true", "false"])
.transform((v) => v === "true")
.optional()
.catch(undefined)
.openapi({
type: "boolean",
description: "Filter by online status"
}),
agent: z
.enum([
"windows",
"android",
"cli",
"olm",
"macos",
"ios",
"ipados",
"unknown"
])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: [
"windows",
"android",
"cli",
"olm",
"macos",
"ios",
"ipados",
"unknown"
],
description:
"Filter by agent type. Use 'unknown' to filter clients with no agent detected."
}),
status: z.preprocess(
(val: string | undefined) => {
if (val) {
return val.split(","); // the search query array is an array joined by commas
}
return undefined;
},
z
.array(
z.enum(["active", "pending", "denied", "blocked", "archived"])
)
.optional()
.default(["active", "pending"])
.catch(["active", "pending"])
.openapi({
type: "array",
items: {
type: "string",
enum: ["active", "pending", "denied", "blocked", "archived"]
},
default: ["active", "pending"],
description:
"Filter by device status. Can include multiple values separated by commas. 'active' means not archived, not blocked, and if approval is enabled, approved. 'pending' and 'denied' are only applicable if approval is enabled."
})
)
});
function queryUserDevicesBase() {
return db
.select({
clientId: clients.clientId,
orgId: clients.orgId,
name: clients.name,
pubKey: clients.pubKey,
subnet: clients.subnet,
megabytesIn: clients.megabytesIn,
megabytesOut: clients.megabytesOut,
orgName: orgs.name,
type: clients.type,
online: clients.online,
olmVersion: olms.version,
userId: clients.userId,
username: users.username,
userEmail: users.email,
niceId: clients.niceId,
agent: olms.agent,
approvalState: clients.approvalState,
olmArchived: olms.archived,
archived: clients.archived,
blocked: clients.blocked,
deviceModel: currentFingerprint.deviceModel,
fingerprintPlatform: currentFingerprint.platform,
fingerprintOsVersion: currentFingerprint.osVersion,
fingerprintKernelVersion: currentFingerprint.kernelVersion,
fingerprintArch: currentFingerprint.arch,
fingerprintSerialNumber: currentFingerprint.serialNumber,
fingerprintUsername: currentFingerprint.username,
fingerprintHostname: currentFingerprint.hostname
})
.from(clients)
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
.leftJoin(olms, eq(clients.clientId, olms.clientId))
.leftJoin(users, eq(clients.userId, users.userId))
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId));
}
type OlmWithUpdateAvailable = Awaited<
ReturnType<typeof queryUserDevicesBase>
>[0] & {
olmUpdateAvailable?: boolean;
};
export type ListUserDevicesResponse = PaginatedResponse<{
devices: Array<OlmWithUpdateAvailable>;
}>;
registry.registerPath({
method: "get",
path: "/org/{orgId}/user-devices",
description: "List all user devices for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
request: {
query: listUserDevicesSchema,
params: listUserDevicesParamsSchema
},
responses: {}
});
export async function listUserDevices(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedQuery = listUserDevicesSchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error)
)
);
}
const { page, pageSize, query, sort_by, online, status, agent, order } =
parsedQuery.data;
const parsedParams = listUserDevicesParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error)
)
);
}
const { orgId } = parsedParams.data;
if (req.user && orgId && orgId !== req.userOrgId) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not have access to this organization"
)
);
}
let accessibleClients;
if (req.user) {
accessibleClients = await db
.select({
clientId: sql<number>`COALESCE(${userClients.clientId}, ${roleClients.clientId})`
})
.from(userClients)
.fullJoin(
roleClients,
eq(userClients.clientId, roleClients.clientId)
)
.where(
or(
eq(userClients.userId, req.user!.userId),
eq(roleClients.roleId, req.userOrgRoleId!)
)
);
} else {
accessibleClients = await db
.select({ clientId: clients.clientId })
.from(clients)
.where(eq(clients.orgId, orgId));
}
const accessibleClientIds = accessibleClients.map(
(client) => client.clientId
);
// Get client count with filter
const conditions = [
and(
inArray(clients.clientId, accessibleClientIds),
eq(clients.orgId, orgId),
isNotNull(clients.userId)
)
];
if (query) {
conditions.push(
or(
like(
sql`LOWER(${clients.name})`,
"%" + query.toLowerCase() + "%"
),
like(
sql`LOWER(${clients.niceId})`,
"%" + query.toLowerCase() + "%"
),
like(
sql`LOWER(${users.email})`,
"%" + query.toLowerCase() + "%"
)
)
);
}
if (typeof online !== "undefined") {
conditions.push(eq(clients.online, online));
}
const agentValueMap = {
windows: "Pangolin Windows",
android: "Pangolin Android",
ios: "Pangolin iOS",
ipados: "Pangolin iPadOS",
macos: "Pangolin macOS",
cli: "Pangolin CLI",
olm: "Olm CLI"
} satisfies Record<
Exclude<typeof agent, undefined | "unknown">,
string
>;
if (typeof agent !== "undefined") {
if (agent === "unknown") {
conditions.push(isNull(olms.agent));
} else {
conditions.push(eq(olms.agent, agentValueMap[agent]));
}
}
if (status.length > 0) {
const filterAggregates: (SQL<unknown> | undefined)[] = [];
if (status.includes("active")) {
filterAggregates.push(
and(
eq(clients.archived, false),
eq(clients.blocked, false),
build !== "oss"
? or(
eq(clients.approvalState, "approved"),
isNull(clients.approvalState) // approval state of `NULL` means approved by default
)
: undefined // undefined are automatically ignored by `drizzle-orm`
)
);
}
if (status.includes("archived")) {
filterAggregates.push(eq(clients.archived, true));
}
if (status.includes("blocked")) {
filterAggregates.push(eq(clients.blocked, true));
}
if (build !== "oss") {
if (status.includes("pending")) {
filterAggregates.push(eq(clients.approvalState, "pending"));
}
if (status.includes("denied")) {
filterAggregates.push(eq(clients.approvalState, "denied"));
}
}
conditions.push(or(...filterAggregates));
}
const baseQuery = queryUserDevicesBase().where(and(...conditions));
const countQuery = db.$count(baseQuery.as("filtered_clients"));
const listDevicesQuery = baseQuery
.limit(pageSize)
.offset(pageSize * (page - 1))
.orderBy(
sort_by
? order === "asc"
? asc(clients[sort_by])
: desc(clients[sort_by])
: asc(clients.clientId)
);
const [clientsList, totalCount] = await Promise.all([
listDevicesQuery,
countQuery
]);
// Merge clients with their site associations and replace name with device name
const olmsWithUpdates: OlmWithUpdateAvailable[] = clientsList.map(
(client) => {
const model = client.deviceModel || null;
const newName = getUserDeviceName(model, client.name);
const OlmWithUpdate: OlmWithUpdateAvailable = {
...client,
name: newName
};
// Initially set to false, will be updated if version check succeeds
OlmWithUpdate.olmUpdateAvailable = false;
return OlmWithUpdate;
}
);
// Try to get the latest version, but don't block if it fails
try {
const latestOlmVersion = await getLatestOlmVersion();
if (latestOlmVersion) {
olmsWithUpdates.forEach((client) => {
try {
client.olmUpdateAvailable = semver.lt(
client.olmVersion ? client.olmVersion : "",
latestOlmVersion
);
} catch (error) {
client.olmUpdateAvailable = false;
}
});
}
} catch (error) {
// Log the error but don't let it block the response
logger.warn(
"Failed to check for OLM updates, continuing without update info:",
error
);
}
return response<ListUserDevicesResponse>(res, {
data: {
devices: olmsWithUpdates,
pagination: {
total: totalCount,
page,
pageSize
}
},
success: true,
error: false,
message: "Clients retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -148,7 +148,6 @@ export async function createOrgDomain(
} }
} }
let numOrgDomains: OrgDomains[] | undefined;
let aRecords: CreateDomainResponse["aRecords"]; let aRecords: CreateDomainResponse["aRecords"];
let cnameRecords: CreateDomainResponse["cnameRecords"]; let cnameRecords: CreateDomainResponse["cnameRecords"];
let txtRecords: CreateDomainResponse["txtRecords"]; let txtRecords: CreateDomainResponse["txtRecords"];
@@ -347,20 +346,9 @@ export async function createOrgDomain(
await trx.insert(dnsRecords).values(recordsToInsert); await trx.insert(dnsRecords).values(recordsToInsert);
} }
numOrgDomains = await trx await usageService.add(orgId, FeatureId.DOMAINS, 1, trx);
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
}); });
if (numOrgDomains) {
await usageService.updateCount(
orgId,
FeatureId.DOMAINS,
numOrgDomains.length
);
}
if (!returned) { if (!returned) {
return next( return next(
createHttpError( createHttpError(

View File

@@ -36,8 +36,6 @@ export async function deleteAccountDomain(
} }
const { domainId, orgId } = parsed.data; const { domainId, orgId } = parsed.data;
let numOrgDomains: OrgDomains[] | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [existing] = await trx const [existing] = await trx
.select() .select()
@@ -79,20 +77,9 @@ export async function deleteAccountDomain(
await trx.delete(domains).where(eq(domains.domainId, domainId)); await trx.delete(domains).where(eq(domains.domainId, domainId));
numOrgDomains = await trx await usageService.add(orgId, FeatureId.DOMAINS, -1, trx);
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
}); });
if (numOrgDomains) {
await usageService.updateCount(
orgId,
FeatureId.DOMAINS,
numOrgDomains.length
);
}
return response<DeleteAccountDomainResponse>(res, { return response<DeleteAccountDomainResponse>(res, {
data: { success: true }, data: { success: true },
success: true, success: true,

View File

@@ -50,6 +50,7 @@ import createHttpError from "http-errors";
import { build } from "@server/build"; import { build } from "@server/build";
import { createStore } from "#dynamic/lib/rateLimitStore"; import { createStore } from "#dynamic/lib/rateLimitStore";
import { logActionAudit } from "#dynamic/middlewares"; import { logActionAudit } from "#dynamic/middlewares";
import { checkRoundTripMessage } from "./ws";
// Root routes // Root routes
export const unauthenticated = Router(); export const unauthenticated = Router();
@@ -64,9 +65,8 @@ authenticated.use(verifySessionUserMiddleware);
authenticated.get("/pick-org-defaults", org.pickOrgDefaults); authenticated.get("/pick-org-defaults", org.pickOrgDefaults);
authenticated.get("/org/checkId", org.checkId); authenticated.get("/org/checkId", org.checkId);
if (build === "oss" || build === "enterprise") {
authenticated.put("/org", getUserOrgs, org.createOrg); authenticated.put("/org", getUserOrgs, org.createOrg);
}
authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs); authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs);
authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs); authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs);
@@ -86,7 +86,6 @@ authenticated.post(
org.updateOrg org.updateOrg
); );
if (build !== "saas") {
authenticated.delete( authenticated.delete(
"/org/:orgId", "/org/:orgId",
verifyOrgAccess, verifyOrgAccess,
@@ -95,7 +94,6 @@ if (build !== "saas") {
logActionAudit(ActionsEnum.deleteOrg), logActionAudit(ActionsEnum.deleteOrg),
org.deleteOrg org.deleteOrg
); );
}
authenticated.put( authenticated.put(
"/org/:orgId/site", "/org/:orgId/site",
@@ -145,6 +143,13 @@ authenticated.get(
client.listClients client.listClients
); );
authenticated.get(
"/org/:orgId/user-devices",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listClients),
client.listUserDevices
);
authenticated.get( authenticated.get(
"/client/:clientId", "/client/:clientId",
verifyClientAccess, verifyClientAccess,
@@ -1116,6 +1121,8 @@ authenticated.get(
blueprints.getBlueprint blueprints.getBlueprint
); );
authenticated.get("/ws/round-trip-message/:messageId", checkRoundTripMessage);
// Auth routes // Auth routes
export const authRouter = Router(); export const authRouter = Router();
unauthenticated.use("/auth", authRouter); unauthenticated.use("/auth", authRouter);

View File

@@ -70,6 +70,15 @@ export async function createIdpOrgPolicy(
const { idpId, orgId } = parsedParams.data; const { idpId, orgId } = parsedParams.data;
const { roleMapping, orgMapping } = parsedBody.data; const { roleMapping, orgMapping } = parsedBody.data;
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
)
);
}
const [existing] = await db const [existing] = await db
.select() .select()
.from(idp) .from(idp)

View File

@@ -80,6 +80,17 @@ export async function createOidcIdp(
tags tags
} = parsedBody.data; } = parsedBody.data;
if (
process.env.IDENTITY_PROVIDER_MODE === "org"
) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
)
);
}
const key = config.getRawConfig().server.secret!; const key = config.getRawConfig().server.secret!;
const encryptedSecret = encrypt(clientSecret, key); const encryptedSecret = encrypt(clientSecret, key);

View File

@@ -69,6 +69,15 @@ export async function updateIdpOrgPolicy(
const { idpId, orgId } = parsedParams.data; const { idpId, orgId } = parsedParams.data;
const { roleMapping, orgMapping } = parsedBody.data; const { roleMapping, orgMapping } = parsedBody.data;
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
)
);
}
// Check if IDP and policy exist // Check if IDP and policy exist
const [existing] = await db const [existing] = await db
.select() .select()

View File

@@ -99,6 +99,15 @@ export async function updateOidcIdp(
tags tags
} = parsedBody.data; } = parsedBody.data;
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
)
);
}
// Check if IDP exists and is of type OIDC // Check if IDP exists and is of type OIDC
const [existingIdp] = await db const [existingIdp] = await db
.select() .select()

View File

@@ -36,6 +36,10 @@ import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { tierMatrix } from "@server/lib/billing/tierMatrix";
import {
assignUserToOrg,
removeUserFromOrg
} from "@server/lib/userOrg";
const ensureTrailingSlash = (url: string): string => { const ensureTrailingSlash = (url: string): string => {
return url; return url;
@@ -436,6 +440,7 @@ export async function validateOidcCallback(
} }
} }
// These are the orgs that the user should be provisioned into based on the IdP mappings and the token claims
logger.debug("User org info", { userOrgInfo }); logger.debug("User org info", { userOrgInfo });
let existingUserId = existingUser?.userId; let existingUserId = existingUser?.userId;
@@ -455,14 +460,31 @@ export async function validateOidcCallback(
if (!existingUserOrgs.length) { if (!existingUserOrgs.length) {
// delete all auto-provisioned user orgs // delete all auto-provisioned user orgs
await db const autoProvisionedUserOrgs = await db
.delete(userOrgs) .select()
.from(userOrgs)
.where( .where(
and( and(
eq(userOrgs.userId, existingUser.userId), eq(userOrgs.userId, existingUser.userId),
eq(userOrgs.autoProvisioned, true) eq(userOrgs.autoProvisioned, true)
) )
); );
const orgIdsToRemove = autoProvisionedUserOrgs.map(
(uo) => uo.orgId
);
if (orgIdsToRemove.length > 0) {
const orgsToRemove = await db
.select()
.from(orgs)
.where(inArray(orgs.orgId, orgIdsToRemove));
for (const org of orgsToRemove) {
await removeUserFromOrg(
org,
existingUser.userId,
db
);
}
}
await calculateUserClientsForOrgs(existingUser.userId); await calculateUserClientsForOrgs(existingUser.userId);
@@ -538,15 +560,14 @@ export async function validateOidcCallback(
); );
if (orgsToDelete.length > 0) { if (orgsToDelete.length > 0) {
await trx.delete(userOrgs).where( const orgIdsToRemove = orgsToDelete.map((org) => org.orgId);
and( const fullOrgsToRemove = await trx
eq(userOrgs.userId, userId!), .select()
inArray( .from(orgs)
userOrgs.orgId, .where(inArray(orgs.orgId, orgIdsToRemove));
orgsToDelete.map((org) => org.orgId) for (const org of fullOrgsToRemove) {
) await removeUserFromOrg(org, userId!, trx);
) }
);
} }
// Update roles for existing auto-provisioned orgs where the role has changed // Update roles for existing auto-provisioned orgs where the role has changed
@@ -587,16 +608,25 @@ export async function validateOidcCallback(
); );
if (orgsToAdd.length > 0) { if (orgsToAdd.length > 0) {
await trx.insert(userOrgs).values( for (const org of orgsToAdd) {
orgsToAdd.map((org) => ({ const [fullOrg] = await trx
userId: userId!, .select()
.from(orgs)
.where(eq(orgs.orgId, org.orgId));
if (fullOrg) {
await assignUserToOrg(
fullOrg,
{
orgId: org.orgId, orgId: org.orgId,
userId: userId!,
roleId: org.roleId, roleId: org.roleId,
autoProvisioned: true, autoProvisioned: true,
dateCreated: new Date().toISOString() },
})) trx
); );
} }
}
}
// Loop through all the orgs and get the total number of users from the userOrgs table // Loop through all the orgs and get the total number of users from the userOrgs table
// Use all current user orgs (both auto-provisioned and manually added) for counting // Use all current user orgs (both auto-provisioned and manually added) for counting

View File

@@ -866,6 +866,13 @@ authenticated.get(
client.listClients client.listClients
); );
authenticated.get(
"/org/:orgId/user-devices",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.listClients),
client.listUserDevices
);
authenticated.get( authenticated.get(
"/client/:clientId", "/client/:clientId",
verifyApiKeyClientAccess, verifyApiKeyClientAccess,

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db } from "@server/db";
import { eq } from "drizzle-orm"; import { and, count, eq } from "drizzle-orm";
import { import {
domains, domains,
Org, Org,
@@ -24,13 +24,24 @@ import { OpenAPITags, registry } from "@server/openApi";
import { isValidCIDR } from "@server/lib/validators"; import { isValidCIDR } from "@server/lib/validators";
import { createCustomer } from "#dynamic/lib/billing"; import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing"; import { FeatureId, limitsService, freeLimitSet } from "@server/lib/billing";
import { build } from "@server/build"; import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { doCidrsOverlap } from "@server/lib/ip"; import { doCidrsOverlap } from "@server/lib/ip";
import { generateCA } from "@server/private/lib/sshCA";
import { encrypt } from "@server/lib/crypto";
const validOrgIdRegex = /^[a-z0-9_]+(-[a-z0-9_]+)*$/;
const createOrgSchema = z.strictObject({ const createOrgSchema = z.strictObject({
orgId: z.string(), orgId: z
.string()
.min(1, "Organization ID is required")
.max(32, "Organization ID must be at most 32 characters")
.refine((val) => validOrgIdRegex.test(val), {
message:
"Organization ID must contain only lowercase letters, numbers, underscores, and single hyphens (no leading, trailing, or consecutive hyphens)"
}),
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
subnet: z subnet: z
// .union([z.cidrv4(), z.cidrv6()]) // .union([z.cidrv4(), z.cidrv6()])
@@ -108,6 +119,7 @@ export async function createOrg(
// ) // )
// ); // );
// } // }
//
// make sure the orgId is unique // make sure the orgId is unique
const orgExists = await db const orgExists = await db
@@ -134,8 +146,71 @@ export async function createOrg(
); );
} }
let isFirstOrg: boolean | null = null;
let billingOrgIdForNewOrg: string | null = null;
if (build === "saas" && req.user) {
const ownedOrgs = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, req.user.userId),
eq(userOrgs.isOwner, true)
)
);
if (ownedOrgs.length === 0) {
isFirstOrg = true;
} else {
isFirstOrg = false;
const [billingOrg] = await db
.select({ orgId: orgs.orgId })
.from(orgs)
.innerJoin(userOrgs, eq(orgs.orgId, userOrgs.orgId))
.where(
and(
eq(userOrgs.userId, req.user.userId),
eq(userOrgs.isOwner, true),
eq(orgs.isBillingOrg, true)
)
)
.limit(1);
if (billingOrg) {
billingOrgIdForNewOrg = billingOrg.orgId;
}
}
}
if (build == "saas" && billingOrgIdForNewOrg) {
const usage = await usageService.getUsage(billingOrgIdForNewOrg, FeatureId.ORGINIZATIONS);
if (!usage) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No usage data found for this organization"
)
);
}
const rejectOrgs = await usageService.checkLimitSet(
billingOrgIdForNewOrg,
FeatureId.ORGINIZATIONS,
{
...usage,
instantaneousValue: (usage.instantaneousValue || 0) + 1
} // We need to add one to know if we are violating the limit
);
if (rejectOrgs) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization limit exceeded. Please upgrade your plan."
)
);
}
}
let error = ""; let error = "";
let org: Org | null = null; let org: Org | null = null;
let numOrgs: number | null = null;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const allDomains = await trx const allDomains = await trx
@@ -143,6 +218,21 @@ export async function createOrg(
.from(domains) .from(domains)
.where(eq(domains.configManaged, true)); .where(eq(domains.configManaged, true));
// Generate SSH CA keys for the org
// const ca = generateCA(`${orgId}-ca`);
// const encryptionKey = config.getRawConfig().server.secret!;
// const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey);
const saasBillingFields =
build === "saas" && req.user && isFirstOrg !== null
? isFirstOrg
? { isBillingOrg: true as const, billingOrgId: orgId } // if this is the first org, it becomes the billing org for itself
: {
isBillingOrg: false as const,
billingOrgId: billingOrgIdForNewOrg
}
: {};
const newOrg = await trx const newOrg = await trx
.insert(orgs) .insert(orgs)
.values({ .values({
@@ -150,7 +240,10 @@ export async function createOrg(
name, name,
subnet, subnet,
utilitySubnet, utilitySubnet,
createdAt: new Date().toISOString() createdAt: new Date().toISOString(),
// sshCaPrivateKey: encryptedCaPrivateKey,
// sshCaPublicKey: ca.publicKeyOpenSSH,
...saasBillingFields
}) })
.returning(); .returning();
@@ -252,6 +345,17 @@ export async function createOrg(
); );
await calculateUserClientsForOrgs(ownerUserId, trx); await calculateUserClientsForOrgs(ownerUserId, trx);
if (billingOrgIdForNewOrg) {
const [numOrgsResult] = await trx
.select({ count: count() })
.from(orgs)
.where(eq(orgs.billingOrgId, billingOrgIdForNewOrg)); // all the billable orgs including the primary org that is the billing org itself
numOrgs = numOrgsResult.count;
} else {
numOrgs = 1; // we only have one org if there is no billing org found out
}
}); });
if (!org) { if (!org) {
@@ -267,8 +371,8 @@ export async function createOrg(
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)); return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error));
} }
if (build == "saas") { if (build === "saas" && isFirstOrg === true) {
// make sure we have the stripe customer await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
const customerId = await createCustomer(orgId, req.user?.email); const customerId = await createCustomer(orgId, req.user?.email);
if (customerId) { if (customerId) {
await usageService.updateCount( await usageService.updateCount(
@@ -280,6 +384,14 @@ export async function createOrg(
} }
} }
if (numOrgs) {
usageService.updateCount(
billingOrgIdForNewOrg || orgId,
FeatureId.ORGINIZATIONS,
numOrgs
);
}
return response(res, { return response(res, {
data: org, data: org,
success: true, success: true,

View File

@@ -7,6 +7,8 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg"; import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg";
import { db, userOrgs, orgs } from "@server/db";
import { eq, and } from "drizzle-orm";
const deleteOrgSchema = z.strictObject({ const deleteOrgSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -41,6 +43,48 @@ export async function deleteOrg(
); );
} }
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const [data] = await db
.select()
.from(userOrgs)
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, req.user!.userId)
)
);
const org = data?.orgs;
const userOrg = data?.userOrgs;
if (!org || !userOrg) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
if (!userOrg.isOwner) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Only organization owners can delete the organization"
)
);
}
if (org.isBillingOrg) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Cannot delete a primary organization"
)
);
}
const result = await deleteOrgById(orgId); const result = await deleteOrgById(orgId);
sendTerminationMessages(result); sendTerminationMessages(result);
return response(res, { return response(res, {

View File

@@ -40,7 +40,11 @@ const listOrgsSchema = z.object({
// responses: {} // responses: {}
// }); // });
type ResponseOrg = Org & { isOwner?: boolean; isAdmin?: boolean }; type ResponseOrg = Org & {
isOwner?: boolean;
isAdmin?: boolean;
isPrimaryOrg?: boolean;
};
export type ListUserOrgsResponse = { export type ListUserOrgsResponse = {
orgs: ResponseOrg[]; orgs: ResponseOrg[];
@@ -132,6 +136,9 @@ export async function listUserOrgs(
if (val.roles && val.roles.isAdmin) { if (val.roles && val.roles.isAdmin) {
res.isAdmin = val.roles.isAdmin; res.isAdmin = val.roles.isAdmin;
} }
if (val.userOrgs?.isOwner && val.orgs?.isBillingOrg) {
res.isPrimaryOrg = val.orgs.isBillingOrg;
}
return res; return res;
}); });

View File

@@ -8,7 +8,10 @@ import {
userOrgs, userOrgs,
resourcePassword, resourcePassword,
resourcePincode, resourcePincode,
resourceWhitelist resourceWhitelist,
siteResources,
userSiteResources,
roleSiteResources
} from "@server/db"; } from "@server/db";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -57,9 +60,21 @@ export async function getUserResources(
.from(roleResources) .from(roleResources)
.where(eq(roleResources.roleId, userRoleId)); .where(eq(roleResources.roleId, userRoleId));
const [directResources, roleResourceResults] = await Promise.all([ const directSiteResourcesQuery = db
.select({ siteResourceId: userSiteResources.siteResourceId })
.from(userSiteResources)
.where(eq(userSiteResources.userId, userId));
const roleSiteResourcesQuery = db
.select({ siteResourceId: roleSiteResources.siteResourceId })
.from(roleSiteResources)
.where(eq(roleSiteResources.roleId, userRoleId));
const [directResources, roleResourceResults, directSiteResourceResults, roleSiteResourceResults] = await Promise.all([
directResourcesQuery, directResourcesQuery,
roleResourcesQuery roleResourcesQuery,
directSiteResourcesQuery,
roleSiteResourcesQuery
]); ]);
// Combine all accessible resource IDs // Combine all accessible resource IDs
@@ -68,18 +83,25 @@ export async function getUserResources(
...roleResourceResults.map((r) => r.resourceId) ...roleResourceResults.map((r) => r.resourceId)
]; ];
if (accessibleResourceIds.length === 0) { // Combine all accessible site resource IDs
return response(res, { const accessibleSiteResourceIds = [
data: { resources: [] }, ...directSiteResourceResults.map((r) => r.siteResourceId),
success: true, ...roleSiteResourceResults.map((r) => r.siteResourceId)
error: false, ];
message: "No resources found",
status: HttpCode.OK
});
}
// Get resource details for accessible resources // Get resource details for accessible resources
const resourcesData = await db let resourcesData: Array<{
resourceId: number;
name: string;
fullDomain: string | null;
ssl: boolean;
enabled: boolean;
sso: boolean;
protocol: string;
emailWhitelistEnabled: boolean;
}> = [];
if (accessibleResourceIds.length > 0) {
resourcesData = await db
.select({ .select({
resourceId: resources.resourceId, resourceId: resources.resourceId,
name: resources.name, name: resources.name,
@@ -98,6 +120,40 @@ export async function getUserResources(
eq(resources.enabled, true) eq(resources.enabled, true)
) )
); );
}
// Get site resource details for accessible site resources
let siteResourcesData: Array<{
siteResourceId: number;
name: string;
destination: string;
mode: string;
protocol: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
}> = [];
if (accessibleSiteResourceIds.length > 0) {
siteResourcesData = await db
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name,
destination: siteResources.destination,
mode: siteResources.mode,
protocol: siteResources.protocol,
enabled: siteResources.enabled,
alias: siteResources.alias,
aliasAddress: siteResources.aliasAddress
})
.from(siteResources)
.where(
and(
inArray(siteResources.siteResourceId, accessibleSiteResourceIds),
eq(siteResources.orgId, orgId),
eq(siteResources.enabled, true)
)
);
}
// Check for password, pincode, and whitelist protection for each resource // Check for password, pincode, and whitelist protection for each resource
const resourcesWithAuth = await Promise.all( const resourcesWithAuth = await Promise.all(
@@ -161,8 +217,26 @@ export async function getUserResources(
}) })
); );
// Format site resources
const siteResourcesFormatted = siteResourcesData.map((siteResource) => {
return {
siteResourceId: siteResource.siteResourceId,
name: siteResource.name,
destination: siteResource.destination,
mode: siteResource.mode,
protocol: siteResource.protocol,
enabled: siteResource.enabled,
alias: siteResource.alias,
aliasAddress: siteResource.aliasAddress,
type: 'site' as const
};
});
return response(res, { return response(res, {
data: { resources: resourcesWithAuth }, data: {
resources: resourcesWithAuth,
siteResources: siteResourcesFormatted
},
success: true, success: true,
error: false, error: false,
message: "User resources retrieved successfully", message: "User resources retrieved successfully",
@@ -190,5 +264,16 @@ export type GetUserResourcesResponse = {
protected: boolean; protected: boolean;
protocol: string; protocol: string;
}>; }>;
siteResources: Array<{
siteResourceId: number;
name: string;
destination: string;
mode: string;
protocol: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
type: 'site';
}>;
}; };
}; };

View File

@@ -1,74 +1,99 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { import {
db, db,
resourceHeaderAuth, resourceHeaderAuth,
resourceHeaderAuthExtendedCompatibility resourceHeaderAuthExtendedCompatibility,
} from "@server/db";
import {
resources,
userResources,
roleResources,
resourcePassword, resourcePassword,
resourcePincode, resourcePincode,
resources,
roleResources,
targetHealthCheck,
targets, targets,
targetHealthCheck userResources
} from "@server/db"; } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { sql, eq, or, inArray, and, count } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import type { PaginatedResponse } from "@server/types/Pagination";
import {
and,
asc,
count,
eq,
inArray,
isNull,
like,
not,
or,
sql,
type SQL
} from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromZodError } from "zod-validation-error";
const listResourcesParamsSchema = z.strictObject({ const listResourcesParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
}); });
const listResourcesSchema = z.object({ const listResourcesSchema = z.object({
limit: z pageSize: z.coerce
.string() .number<string>() // for prettier formatting
.int()
.positive()
.optional() .optional()
.default("1000") .catch(20)
.transform(Number) .default(20)
.pipe(z.int().nonnegative()), .openapi({
type: "integer",
offset: z default: 20,
.string() description: "Number of items per page"
}),
page: z.coerce
.number<string>() // for prettier formatting
.int()
.min(0)
.optional() .optional()
.default("0") .catch(1)
.transform(Number) .default(1)
.pipe(z.int().nonnegative()) .openapi({
type: "integer",
default: 1,
description: "Page number to retrieve"
}),
query: z.string().optional(),
enabled: z
.enum(["true", "false"])
.transform((v) => v === "true")
.optional()
.catch(undefined)
.openapi({
type: "boolean",
description: "Filter resources based on enabled status"
}),
authState: z
.enum(["protected", "not_protected", "none"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["protected", "not_protected", "none"],
description:
"Filter resources based on authentication state. `protected` means the resource has at least one auth mechanism (password, pincode, header auth, SSO, or email whitelist). `not_protected` means the resource has no auth mechanisms. `none` means the resource is not protected by HTTP (i.e. it has no auth mechanisms and http is false)."
}),
healthStatus: z
.enum(["no_targets", "healthy", "degraded", "offline", "unknown"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["no_targets", "healthy", "degraded", "offline", "unknown"],
description:
"Filter resources based on health status of their targets. `healthy` means all targets are healthy. `degraded` means at least one target is unhealthy, but not all are unhealthy. `offline` means all targets are unhealthy. `unknown` means all targets have unknown health status. `no_targets` means the resource has no targets."
})
}); });
// (resource fields + a single joined target)
type JoinedRow = {
resourceId: number;
niceId: string;
name: string;
ssl: boolean;
fullDomain: string | null;
passwordId: number | null;
sso: boolean;
pincodeId: number | null;
whitelist: boolean;
http: boolean;
protocol: string;
proxyPort: number | null;
enabled: boolean;
domainId: string | null;
headerAuthId: number | null;
targetId: number | null;
targetIp: string | null;
targetPort: number | null;
targetEnabled: boolean | null;
hcHealth: string | null;
hcEnabled: boolean | null;
};
// grouped by resource with targets[]) // grouped by resource with targets[])
export type ResourceWithTargets = { export type ResourceWithTargets = {
resourceId: number; resourceId: number;
@@ -91,11 +116,32 @@ export type ResourceWithTargets = {
ip: string; ip: string;
port: number; port: number;
enabled: boolean; enabled: boolean;
healthStatus?: "healthy" | "unhealthy" | "unknown"; healthStatus: "healthy" | "unhealthy" | "unknown" | null;
}>; }>;
}; };
function queryResources(accessibleResourceIds: number[], orgId: string) { // Aggregate filters
const total_targets = count(targets.targetId);
const healthy_targets = sql<number>`SUM(
CASE
WHEN ${targetHealthCheck.hcHealth} = 'healthy' THEN 1
ELSE 0
END
) `;
const unknown_targets = sql<number>`SUM(
CASE
WHEN ${targetHealthCheck.hcHealth} = 'unknown' THEN 1
ELSE 0
END
) `;
const unhealthy_targets = sql<number>`SUM(
CASE
WHEN ${targetHealthCheck.hcHealth} = 'unhealthy' THEN 1
ELSE 0
END
) `;
function queryResourcesBase() {
return db return db
.select({ .select({
resourceId: resources.resourceId, resourceId: resources.resourceId,
@@ -114,14 +160,7 @@ function queryResources(accessibleResourceIds: number[], orgId: string) {
niceId: resources.niceId, niceId: resources.niceId,
headerAuthId: resourceHeaderAuth.headerAuthId, headerAuthId: resourceHeaderAuth.headerAuthId,
headerAuthExtendedCompatibilityId: headerAuthExtendedCompatibilityId:
resourceHeaderAuthExtendedCompatibility.headerAuthExtendedCompatibilityId, resourceHeaderAuthExtendedCompatibility.headerAuthExtendedCompatibilityId
targetId: targets.targetId,
targetIp: targets.ip,
targetPort: targets.port,
targetEnabled: targets.enabled,
hcHealth: targetHealthCheck.hcHealth,
hcEnabled: targetHealthCheck.hcEnabled
}) })
.from(resources) .from(resources)
.leftJoin( .leftJoin(
@@ -148,18 +187,18 @@ function queryResources(accessibleResourceIds: number[], orgId: string) {
targetHealthCheck, targetHealthCheck,
eq(targetHealthCheck.targetId, targets.targetId) eq(targetHealthCheck.targetId, targets.targetId)
) )
.where( .groupBy(
and( resources.resourceId,
inArray(resources.resourceId, accessibleResourceIds), resourcePassword.passwordId,
eq(resources.orgId, orgId) resourcePincode.pincodeId,
) resourceHeaderAuth.headerAuthId,
resourceHeaderAuthExtendedCompatibility.headerAuthExtendedCompatibilityId
); );
} }
export type ListResourcesResponse = { export type ListResourcesResponse = PaginatedResponse<{
resources: ResourceWithTargets[]; resources: ResourceWithTargets[];
pagination: { total: number; limit: number; offset: number }; }>;
};
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
@@ -190,7 +229,8 @@ export async function listResources(
) )
); );
} }
const { limit, offset } = parsedQuery.data; const { page, pageSize, authState, enabled, query, healthStatus } =
parsedQuery.data;
const parsedParams = listResourcesParamsSchema.safeParse(req.params); const parsedParams = listResourcesParamsSchema.safeParse(req.params);
if (!parsedParams.success) { if (!parsedParams.success) {
@@ -252,14 +292,133 @@ export async function listResources(
(resource) => resource.resourceId (resource) => resource.resourceId
); );
const countQuery: any = db const conditions = [
.select({ count: count() }) and(
.from(resources) inArray(resources.resourceId, accessibleResourceIds),
.where(inArray(resources.resourceId, accessibleResourceIds)); eq(resources.orgId, orgId)
)
];
const baseQuery = queryResources(accessibleResourceIds, orgId); if (query) {
conditions.push(
or(
like(
sql`LOWER(${resources.name})`,
"%" + query.toLowerCase() + "%"
),
like(
sql`LOWER(${resources.niceId})`,
"%" + query.toLowerCase() + "%"
),
like(
sql`LOWER(${resources.fullDomain})`,
"%" + query.toLowerCase() + "%"
)
)
);
}
if (typeof enabled !== "undefined") {
conditions.push(eq(resources.enabled, enabled));
}
const rows: JoinedRow[] = await baseQuery.limit(limit).offset(offset); if (typeof authState !== "undefined") {
switch (authState) {
case "none":
conditions.push(eq(resources.http, false));
break;
case "protected":
conditions.push(
or(
eq(resources.sso, true),
eq(resources.emailWhitelistEnabled, true),
not(isNull(resourceHeaderAuth.headerAuthId)),
not(isNull(resourcePincode.pincodeId)),
not(isNull(resourcePassword.passwordId))
)
);
break;
case "not_protected":
conditions.push(
not(eq(resources.sso, true)),
not(eq(resources.emailWhitelistEnabled, true)),
isNull(resourceHeaderAuth.headerAuthId),
isNull(resourcePincode.pincodeId),
isNull(resourcePassword.passwordId)
);
break;
}
}
let aggregateFilters: SQL<any> | undefined = sql`1 = 1`;
if (typeof healthStatus !== "undefined") {
switch (healthStatus) {
case "healthy":
aggregateFilters = and(
sql`${total_targets} > 0`,
sql`${healthy_targets} = ${total_targets}`
);
break;
case "degraded":
aggregateFilters = and(
sql`${total_targets} > 0`,
sql`${unhealthy_targets} > 0`
);
break;
case "no_targets":
aggregateFilters = sql`${total_targets} = 0`;
break;
case "offline":
aggregateFilters = and(
sql`${total_targets} > 0`,
sql`${healthy_targets} = 0`,
sql`${unhealthy_targets} = ${total_targets}`
);
break;
case "unknown":
aggregateFilters = and(
sql`${total_targets} > 0`,
sql`${unknown_targets} = ${total_targets}`
);
break;
}
}
const baseQuery = queryResourcesBase()
.where(and(...conditions))
.having(aggregateFilters);
// we need to add `as` so that drizzle filters the result as a subquery
const countQuery = db.$count(baseQuery.as("filtered_resources"));
const [rows, totalCount] = await Promise.all([
baseQuery
.limit(pageSize)
.offset(pageSize * (page - 1))
.orderBy(asc(resources.resourceId)),
countQuery
]);
const resourceIdList = rows.map((row) => row.resourceId);
const allResourceTargets =
resourceIdList.length === 0
? []
: await db
.select({
targetId: targets.targetId,
resourceId: targets.resourceId,
ip: targets.ip,
port: targets.port,
enabled: targets.enabled,
healthStatus: targetHealthCheck.hcHealth,
hcEnabled: targetHealthCheck.hcEnabled
})
.from(targets)
.where(inArray(targets.resourceId, resourceIdList))
.leftJoin(
targetHealthCheck,
eq(targetHealthCheck.targetId, targets.targetId)
);
// avoids TS issues with reduce/never[] // avoids TS issues with reduce/never[]
const map = new Map<number, ResourceWithTargets>(); const map = new Map<number, ResourceWithTargets>();
@@ -288,44 +447,20 @@ export async function listResources(
map.set(row.resourceId, entry); map.set(row.resourceId, entry);
} }
if ( entry.targets = allResourceTargets.filter(
row.targetId != null && (t) => t.resourceId === entry.resourceId
row.targetIp && );
row.targetPort != null &&
row.targetEnabled != null
) {
let healthStatus: "healthy" | "unhealthy" | "unknown" =
"unknown";
if (row.hcEnabled && row.hcHealth) {
healthStatus = row.hcHealth as
| "healthy"
| "unhealthy"
| "unknown";
}
entry.targets.push({
targetId: row.targetId,
ip: row.targetIp,
port: row.targetPort,
enabled: row.targetEnabled,
healthStatus: healthStatus
});
}
} }
const resourcesList: ResourceWithTargets[] = Array.from(map.values()); const resourcesList: ResourceWithTargets[] = Array.from(map.values());
const totalCountResult = await countQuery;
const totalCount = totalCountResult[0]?.count ?? 0;
return response<ListResourcesResponse>(res, { return response<ListResourcesResponse>(res, {
data: { data: {
resources: resourcesList, resources: resourcesList,
pagination: { pagination: {
total: totalCount, total: totalCount,
limit, pageSize,
offset page
} }
}, },
success: true, success: true,

View File

@@ -33,7 +33,7 @@ const updateResourceParamsSchema = z.strictObject({
const updateHttpResourceBodySchema = z const updateHttpResourceBodySchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
niceId: z.string().min(1).max(255).optional(), niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
subdomain: subdomainSchema.nullable().optional(), subdomain: subdomainSchema.nullable().optional(),
ssl: z.boolean().optional(), ssl: z.boolean().optional(),
sso: z.boolean().optional(), sso: z.boolean().optional(),

View File

@@ -6,7 +6,7 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq, and } from "drizzle-orm"; import { eq, and, count } from "drizzle-orm";
import { getUniqueSiteName } from "../../db/names"; import { getUniqueSiteName } from "../../db/names";
import { addPeer } from "../gerbil/peers"; import { addPeer } from "../gerbil/peers";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
@@ -288,7 +288,6 @@ export async function createSite(
const niceId = await getUniqueSiteName(orgId); const niceId = await getUniqueSiteName(orgId);
let newSite: Site | undefined; let newSite: Site | undefined;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
if (type == "newt") { if (type == "newt") {
[newSite] = await trx [newSite] = await trx
@@ -443,20 +442,9 @@ export async function createSite(
}); });
} }
numSites = await trx await usageService.add(orgId, FeatureId.SITES, 1, trx);
.select()
.from(sites)
.where(eq(sites.orgId, orgId));
}); });
if (numSites) {
await usageService.updateCount(
orgId,
FeatureId.SITES,
numSites.length
);
}
if (!newSite) { if (!newSite) {
return next( return next(
createHttpError( createHttpError(

View File

@@ -64,7 +64,6 @@ export async function deleteSite(
} }
let deletedNewtId: string | null = null; let deletedNewtId: string | null = null;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
if (site.type == "wireguard") { if (site.type == "wireguard") {
@@ -103,19 +102,9 @@ export async function deleteSite(
await trx.delete(sites).where(eq(sites.siteId, siteId)); await trx.delete(sites).where(eq(sites.siteId, siteId));
numSites = await trx await usageService.add(site.orgId, FeatureId.SITES, -1, trx);
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
}); });
if (numSites) {
await usageService.updateCount(
site.orgId,
FeatureId.SITES,
numSites.length
);
}
// Send termination message outside of transaction to prevent blocking // Send termination message outside of transaction to prevent blocking
if (deletedNewtId) { if (deletedNewtId) {
const payload = { const payload = {

View File

@@ -1,17 +1,25 @@
import { db, exitNodes, newts } from "@server/db"; import {
import { orgs, roleSites, sites, userSites } from "@server/db"; db,
import { remoteExitNodes } from "@server/db"; exitNodes,
import logger from "@server/logger"; newts,
import HttpCode from "@server/types/HttpCode"; orgs,
remoteExitNodes,
roleSites,
sites,
userSites
} from "@server/db";
import cache from "@server/lib/cache";
import response from "@server/lib/response"; import response from "@server/lib/response";
import { and, count, eq, inArray, or, sql } from "drizzle-orm"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import type { PaginatedResponse } from "@server/types/Pagination";
import { and, asc, desc, eq, inArray, like, or, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import semver from "semver";
import { z } from "zod"; import { z } from "zod";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import semver from "semver";
import cache from "@server/lib/cache";
async function getLatestNewtVersion(): Promise<string | null> { async function getLatestNewtVersion(): Promise<string | null> {
try { try {
@@ -74,21 +82,63 @@ const listSitesParamsSchema = z.strictObject({
}); });
const listSitesSchema = z.object({ const listSitesSchema = z.object({
limit: z pageSize: z.coerce
.string() .number<string>() // for prettier formatting
.int()
.positive()
.optional() .optional()
.default("1000") .catch(20)
.transform(Number) .default(20)
.pipe(z.int().positive()), .openapi({
offset: z type: "integer",
.string() default: 20,
description: "Number of items per page"
}),
page: z.coerce
.number<string>() // for prettier formatting
.int()
.min(0)
.optional() .optional()
.default("0") .catch(1)
.transform(Number) .default(1)
.pipe(z.int().nonnegative()) .openapi({
type: "integer",
default: 1,
description: "Page number to retrieve"
}),
query: z.string().optional(),
sort_by: z
.enum(["megabytesIn", "megabytesOut"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["megabytesIn", "megabytesOut"],
description: "Field to sort by"
}),
order: z
.enum(["asc", "desc"])
.optional()
.default("asc")
.catch("asc")
.openapi({
type: "string",
enum: ["asc", "desc"],
default: "asc",
description: "Sort order"
}),
online: z
.enum(["true", "false"])
.transform((v) => v === "true")
.optional()
.catch(undefined)
.openapi({
type: "boolean",
description: "Filter by online status"
})
}); });
function querySites(orgId: string, accessibleSiteIds: number[]) { function querySitesBase() {
return db return db
.select({ .select({
siteId: sites.siteId, siteId: sites.siteId,
@@ -115,23 +165,16 @@ function querySites(orgId: string, accessibleSiteIds: number[]) {
.leftJoin( .leftJoin(
remoteExitNodes, remoteExitNodes,
eq(remoteExitNodes.exitNodeId, sites.exitNodeId) eq(remoteExitNodes.exitNodeId, sites.exitNodeId)
)
.where(
and(
inArray(sites.siteId, accessibleSiteIds),
eq(sites.orgId, orgId)
)
); );
} }
type SiteWithUpdateAvailable = Awaited<ReturnType<typeof querySites>>[0] & { type SiteWithUpdateAvailable = Awaited<ReturnType<typeof querySitesBase>>[0] & {
newtUpdateAvailable?: boolean; newtUpdateAvailable?: boolean;
}; };
export type ListSitesResponse = { export type ListSitesResponse = PaginatedResponse<{
sites: SiteWithUpdateAvailable[]; sites: SiteWithUpdateAvailable[];
pagination: { total: number; limit: number; offset: number }; }>;
};
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
@@ -160,7 +203,6 @@ export async function listSites(
) )
); );
} }
const { limit, offset } = parsedQuery.data;
const parsedParams = listSitesParamsSchema.safeParse(req.params); const parsedParams = listSitesParamsSchema.safeParse(req.params);
if (!parsedParams.success) { if (!parsedParams.success) {
@@ -203,34 +245,67 @@ export async function listSites(
.where(eq(sites.orgId, orgId)); .where(eq(sites.orgId, orgId));
} }
const accessibleSiteIds = accessibleSites.map((site) => site.siteId); const { pageSize, page, query, sort_by, order, online } =
const baseQuery = querySites(orgId, accessibleSiteIds); parsedQuery.data;
const countQuery = db const accessibleSiteIds = accessibleSites.map((site) => site.siteId);
.select({ count: count() })
.from(sites) const conditions = [
.where(
and( and(
inArray(sites.siteId, accessibleSiteIds), inArray(sites.siteId, accessibleSiteIds),
eq(sites.orgId, orgId) eq(sites.orgId, orgId)
) )
];
if (query) {
conditions.push(
or(
like(
sql`LOWER(${sites.name})`,
"%" + query.toLowerCase() + "%"
),
like(
sql`LOWER(${sites.niceId})`,
"%" + query.toLowerCase() + "%"
)
)
);
}
if (typeof online !== "undefined") {
conditions.push(eq(sites.online, online));
}
const baseQuery = querySitesBase().where(and(...conditions));
// we need to add `as` so that drizzle filters the result as a subquery
const countQuery = db.$count(
querySitesBase().where(and(...conditions))
); );
const sitesList = await baseQuery.limit(limit).offset(offset); const siteListQuery = baseQuery
const totalCountResult = await countQuery; .limit(pageSize)
const totalCount = totalCountResult[0].count; .offset(pageSize * (page - 1))
.orderBy(
sort_by
? order === "asc"
? asc(sites[sort_by])
: desc(sites[sort_by])
: asc(sites.siteId)
);
const [totalCount, rows] = await Promise.all([
countQuery,
siteListQuery
]);
// Get latest version asynchronously without blocking the response // Get latest version asynchronously without blocking the response
const latestNewtVersionPromise = getLatestNewtVersion(); const latestNewtVersionPromise = getLatestNewtVersion();
const sitesWithUpdates: SiteWithUpdateAvailable[] = sitesList.map( const sitesWithUpdates: SiteWithUpdateAvailable[] = rows.map((site) => {
(site) => {
const siteWithUpdate: SiteWithUpdateAvailable = { ...site }; const siteWithUpdate: SiteWithUpdateAvailable = { ...site };
// Initially set to false, will be updated if version check succeeds // Initially set to false, will be updated if version check succeeds
siteWithUpdate.newtUpdateAvailable = false; siteWithUpdate.newtUpdateAvailable = false;
return siteWithUpdate; return siteWithUpdate;
} });
);
// Try to get the latest version, but don't block if it fails // Try to get the latest version, but don't block if it fails
try { try {
@@ -267,8 +342,8 @@ export async function listSites(
sites: sitesWithUpdates, sites: sitesWithUpdates,
pagination: { pagination: {
total: totalCount, total: totalCount,
limit, pageSize,
offset page
} }
}, },
success: true, success: true,

View File

@@ -284,7 +284,7 @@ export async function createSiteResource(
niceId, niceId,
orgId, orgId,
name, name,
mode, mode: mode as "host" | "cidr",
// protocol: mode === "port" ? protocol : null, // protocol: mode === "port" ? protocol : null,
// proxyPort: mode === "port" ? proxyPort : null, // proxyPort: mode === "port" ? proxyPort : null,
// destinationPort: mode === "port" ? destinationPort : null, // destinationPort: mode === "port" ? destinationPort : null,

View File

@@ -1,41 +1,90 @@
import { Request, Response, NextFunction } from "express"; import { db, SiteResource, siteResources, sites } from "@server/db";
import { z } from "zod";
import { db } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import type { PaginatedResponse } from "@server/types/Pagination";
import { and, asc, eq, like, or, sql } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const listAllSiteResourcesByOrgParamsSchema = z.strictObject({ const listAllSiteResourcesByOrgParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
}); });
const listAllSiteResourcesByOrgQuerySchema = z.object({ const listAllSiteResourcesByOrgQuerySchema = z.object({
limit: z pageSize: z.coerce
.string() .number<string>() // for prettier formatting
.int()
.positive()
.optional() .optional()
.default("1000") .catch(20)
.transform(Number) .default(20)
.pipe(z.int().positive()), .openapi({
offset: z type: "integer",
.string() default: 20,
description: "Number of items per page"
}),
page: z.coerce
.number<string>() // for prettier formatting
.int()
.min(0)
.optional() .optional()
.default("0") .catch(1)
.transform(Number) .default(1)
.pipe(z.int().nonnegative()) .openapi({
type: "integer",
default: 1,
description: "Page number to retrieve"
}),
query: z.string().optional(),
mode: z
.enum(["host", "cidr"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["host", "cidr"],
description: "Filter site resources by mode"
})
}); });
export type ListAllSiteResourcesByOrgResponse = { export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
siteResources: (SiteResource & { siteResources: (SiteResource & {
siteName: string; siteName: string;
siteNiceId: string; siteNiceId: string;
siteAddress: string | null; siteAddress: string | null;
})[]; })[];
}; }>;
function querySiteResourcesBase() {
return db
.select({
siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId,
niceId: siteResources.niceId,
name: siteResources.name,
mode: siteResources.mode,
protocol: siteResources.protocol,
proxyPort: siteResources.proxyPort,
destinationPort: siteResources.destinationPort,
destination: siteResources.destination,
enabled: siteResources.enabled,
alias: siteResources.alias,
aliasAddress: siteResources.aliasAddress,
tcpPortRangeString: siteResources.tcpPortRangeString,
udpPortRangeString: siteResources.udpPortRangeString,
disableIcmp: siteResources.disableIcmp,
siteName: sites.name,
siteNiceId: sites.niceId,
siteAddress: sites.address
})
.from(siteResources)
.innerJoin(sites, eq(siteResources.siteId, sites.siteId));
}
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
@@ -80,39 +129,67 @@ export async function listAllSiteResourcesByOrg(
} }
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { limit, offset } = parsedQuery.data; const { page, pageSize, query, mode } = parsedQuery.data;
// Get all site resources for the org with site names const conditions = [and(eq(siteResources.orgId, orgId))];
const siteResourcesList = await db if (query) {
.select({ conditions.push(
siteResourceId: siteResources.siteResourceId, or(
siteId: siteResources.siteId, like(
orgId: siteResources.orgId, sql`LOWER(${siteResources.name})`,
niceId: siteResources.niceId, "%" + query.toLowerCase() + "%"
name: siteResources.name, ),
mode: siteResources.mode, like(
protocol: siteResources.protocol, sql`LOWER(${siteResources.niceId})`,
proxyPort: siteResources.proxyPort, "%" + query.toLowerCase() + "%"
destinationPort: siteResources.destinationPort, ),
destination: siteResources.destination, like(
enabled: siteResources.enabled, sql`LOWER(${siteResources.destination})`,
alias: siteResources.alias, "%" + query.toLowerCase() + "%"
aliasAddress: siteResources.aliasAddress, ),
tcpPortRangeString: siteResources.tcpPortRangeString, like(
udpPortRangeString: siteResources.udpPortRangeString, sql`LOWER(${siteResources.alias})`,
disableIcmp: siteResources.disableIcmp, "%" + query.toLowerCase() + "%"
siteName: sites.name, ),
siteNiceId: sites.niceId, like(
siteAddress: sites.address sql`LOWER(${siteResources.aliasAddress})`,
}) "%" + query.toLowerCase() + "%"
.from(siteResources) ),
.innerJoin(sites, eq(siteResources.siteId, sites.siteId)) like(
.where(eq(siteResources.orgId, orgId)) sql`LOWER(${sites.name})`,
.limit(limit) "%" + query.toLowerCase() + "%"
.offset(offset); )
)
);
}
return response(res, { if (mode) {
data: { siteResources: siteResourcesList }, conditions.push(eq(siteResources.mode, mode));
}
const baseQuery = querySiteResourcesBase().where(and(...conditions));
const countQuery = db.$count(
querySiteResourcesBase().where(and(...conditions))
);
const [siteResourcesList, totalCount] = await Promise.all([
baseQuery
.limit(pageSize)
.offset(pageSize * (page - 1))
.orderBy(asc(siteResources.siteResourceId)),
countQuery
]);
return response<ListAllSiteResourcesByOrgResponse>(res, {
data: {
siteResources: siteResourcesList,
pagination: {
total: totalCount,
pageSize,
page
}
},
success: true, success: true,
error: false, error: false,
message: "Site resources retrieved successfully", message: "Site resources retrieved successfully",

View File

@@ -41,6 +41,7 @@ const updateSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
siteId: z.int(), siteId: z.int(),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(), // mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(), mode: z.enum(["host", "cidr"]).optional(),
// protocol: z.enum(["tcp", "udp"]).nullish(), // protocol: z.enum(["tcp", "udp"]).nullish(),

View File

@@ -105,7 +105,10 @@ export const handleHealthcheckStatusMessage: MessageHandler = async (
await db await db
.update(targetHealthCheck) .update(targetHealthCheck)
.set({ .set({
hcHealth: healthStatus.status hcHealth: healthStatus.status as
| "unknown"
| "healthy"
| "unhealthy"
}) })
.where(eq(targetHealthCheck.targetId, targetIdNum)) .where(eq(targetHealthCheck.targetId, targetIdNum))
.execute(); .execute();

View File

@@ -1,8 +1,8 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, UserOrg } from "@server/db"; import { db, orgs, UserOrg } from "@server/db";
import { roles, userInvites, userOrgs, users } from "@server/db"; import { roles, userInvites, userOrgs, users } from "@server/db";
import { eq } from "drizzle-orm"; import { eq, and, inArray, ne } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -14,6 +14,7 @@ import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing"; import { FeatureId } from "@server/lib/billing";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { build } from "@server/build"; import { build } from "@server/build";
import { assignUserToOrg } from "@server/lib/userOrg";
const acceptInviteBodySchema = z.strictObject({ const acceptInviteBodySchema = z.strictObject({
token: z.string(), token: z.string(),
@@ -125,8 +126,22 @@ export async function acceptInvite(
} }
} }
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, existingInvite.orgId))
.limit(1);
if (!org) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization does not exist. Please contact an admin."
)
);
}
let roleId: number; let roleId: number;
let totalUsers: UserOrg[] | undefined;
// get the role to make sure it exists // get the role to make sure it exists
const existingRole = await db const existingRole = await db
.select() .select()
@@ -146,12 +161,15 @@ export async function acceptInvite(
} }
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// add the user to the org await assignUserToOrg(
await trx.insert(userOrgs).values({ org,
{
userId: existingUser[0].userId, userId: existingUser[0].userId,
orgId: existingInvite.orgId, orgId: existingInvite.orgId,
roleId: existingInvite.roleId roleId: existingInvite.roleId
}); },
trx
);
// delete the invite // delete the invite
await trx await trx
@@ -160,25 +178,11 @@ export async function acceptInvite(
await calculateUserClientsForOrgs(existingUser[0].userId, trx); await calculateUserClientsForOrgs(existingUser[0].userId, trx);
// Get the total number of users in the org now
totalUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, existingInvite.orgId));
logger.debug( logger.debug(
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}` `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
); );
}); });
if (totalUsers) {
await usageService.updateCount(
existingInvite.orgId,
FeatureId.USERS,
totalUsers.length
);
}
return response<AcceptInviteResponse>(res, { return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite.orgId }, data: { accepted: true, orgId: existingInvite.orgId },
success: true, success: true,

View File

@@ -6,8 +6,8 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { db, UserOrg } from "@server/db"; import { db, orgs, UserOrg } from "@server/db";
import { and, eq } from "drizzle-orm"; import { and, eq, inArray, ne } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db"; import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app"; import { generateId } from "@server/auth/sessions/app";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
@@ -16,6 +16,7 @@ import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { assignUserToOrg } from "@server/lib/userOrg";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string().nonempty() orgId: z.string().nonempty()
@@ -151,6 +152,21 @@ export async function createOrgUser(
); );
} }
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Organization not found"
)
);
}
const [idpRes] = await db const [idpRes] = await db
.select() .select()
.from(idp) .from(idp)
@@ -172,8 +188,6 @@ export async function createOrgUser(
); );
} }
let orgUsers: UserOrg[] | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [existingUser] = await trx const [existingUser] = await trx
.select() .select()
@@ -207,15 +221,12 @@ export async function createOrgUser(
); );
} }
await trx await assignUserToOrg(org, {
.insert(userOrgs)
.values({
orgId, orgId,
userId: existingUser.userId, userId: existingUser.userId,
roleId: role.roleId, roleId: role.roleId,
autoProvisioned: false autoProvisioned: false
}) }, trx);
.returning();
} else { } else {
userId = generateId(15); userId = generateId(15);
@@ -233,33 +244,16 @@ export async function createOrgUser(
}) })
.returning(); .returning();
await trx await assignUserToOrg(org, {
.insert(userOrgs)
.values({
orgId, orgId,
userId: newUser.userId, userId: newUser.userId,
roleId: role.roleId, roleId: role.roleId,
autoProvisioned: false autoProvisioned: false
}) }, trx);
.returning();
} }
// List all of the users in the org
orgUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
await calculateUserClientsForOrgs(userId, trx); await calculateUserClientsForOrgs(userId, trx);
}); });
if (orgUsers) {
await usageService.updateCount(
orgId,
FeatureId.USERS,
orgUsers.length
);
}
} else { } else {
return next( return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required") createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -1,8 +1,16 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, resources, sites, UserOrg } from "@server/db"; import {
db,
orgs,
resources,
siteResources,
sites,
UserOrg,
userSiteResources
} from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db"; import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, count, eq, exists } from "drizzle-orm"; import { and, count, eq, exists, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -14,6 +22,7 @@ import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build"; import { build } from "@server/build";
import { UserType } from "@server/types/UserTypes"; import { UserType } from "@server/types/UserTypes";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { removeUserFromOrg } from "@server/lib/userOrg";
const removeUserSchema = z.strictObject({ const removeUserSchema = z.strictObject({
userId: z.string(), userId: z.string(),
@@ -50,16 +59,16 @@ export async function removeUserOrg(
const { userId, orgId } = parsedParams.data; const { userId, orgId } = parsedParams.data;
// get the user first // get the user first
const user = await db const [user] = await db
.select() .select()
.from(userOrgs) .from(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))); .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)));
if (!user || user.length === 0) { if (!user) {
return next(createHttpError(HttpCode.NOT_FOUND, "User not found")); return next(createHttpError(HttpCode.NOT_FOUND, "User not found"));
} }
if (user[0].isOwner) { if (user.isOwner) {
return next( return next(
createHttpError( createHttpError(
HttpCode.BAD_REQUEST, HttpCode.BAD_REQUEST,
@@ -68,56 +77,20 @@ export async function removeUserOrg(
); );
} }
let userCount: UserOrg[] | undefined; const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await removeUserFromOrg(org, userId, trx);
.delete(userOrgs)
.where(
and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))
);
await db.delete(userResources).where(
and(
eq(userResources.userId, userId),
exists(
db
.select()
.from(resources)
.where(
and(
eq(
resources.resourceId,
userResources.resourceId
),
eq(resources.orgId, orgId)
)
)
)
)
);
await db.delete(userSites).where(
and(
eq(userSites.userId, userId),
exists(
db
.select()
.from(sites)
.where(
and(
eq(sites.siteId, userSites.siteId),
eq(sites.orgId, orgId)
)
)
)
)
);
userCount = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
// if (build === "saas") { // if (build === "saas") {
// const [rootUser] = await trx // const [rootUser] = await trx
@@ -139,14 +112,6 @@ export async function removeUserOrg(
await calculateUserClientsForOrgs(userId, trx); await calculateUserClientsForOrgs(userId, trx);
}); });
if (userCount) {
await usageService.updateCount(
orgId,
FeatureId.USERS,
userCount.length
);
}
return response(res, { return response(res, {
data: null, data: null,
success: true, success: true,

View File

@@ -0,0 +1,85 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, roundTripMessageTracker } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
const checkRoundTripMessageParamsSchema = z
.object({
messageId: z
.string()
.transform(Number)
.pipe(z.number().int().positive())
})
.strict();
// registry.registerPath({
// method: "get",
// path: "/ws/round-trip-message/{messageId}",
// description:
// "Check if a round trip message has been completed by checking the roundTripMessageTracker table",
// tags: [OpenAPITags.WebSocket],
// request: {
// params: checkRoundTripMessageParamsSchema
// },
// responses: {}
// });
export async function checkRoundTripMessage(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = checkRoundTripMessageParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { messageId } = parsedParams.data;
// Get the round trip message from the tracker
const [message] = await db
.select()
.from(roundTripMessageTracker)
.where(eq(roundTripMessageTracker.messageId, messageId))
.limit(1);
if (!message) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Message not found")
);
}
return response(res, {
data: {
messageId: message.messageId,
complete: message.complete,
sentAt: message.sentAt,
receivedAt: message.receivedAt,
error: message.error,
},
success: true,
error: false,
message: "Round trip message status retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,49 @@
import { db, roundTripMessageTracker } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
interface RoundTripCompleteMessage {
messageId: number;
complete: boolean;
error?: string;
}
export const handleRoundTripMessage: MessageHandler = async (
context
) => {
const { message, client: c } = context;
logger.info("Handling round trip message");
const data = message.data as RoundTripCompleteMessage;
try {
const { messageId, complete, error } = data;
if (!messageId) {
logger.error("Round trip message missing messageId");
return;
}
// Update the roundTripMessageTracker with completion status
await db
.update(roundTripMessageTracker)
.set({
complete: complete,
receivedAt: Math.floor(Date.now() / 1000),
error: error || null
})
.where(eq(roundTripMessageTracker.messageId, messageId));
logger.info(`Round trip message ${messageId} marked as complete: ${complete}`);
if (error) {
logger.warn(`Round trip message ${messageId} completed with error: ${error}`);
}
} catch (error) {
logger.error("Error processing round trip message:", error);
}
return;
};

View File

@@ -1,2 +1,3 @@
export * from "./ws"; export * from "./ws";
export * from "./types"; export * from "./types";
export * from "./checkRoundTripMessage";

View File

@@ -18,6 +18,7 @@ import {
handleOlmDisconnecingMessage handleOlmDisconnecingMessage
} from "../olm"; } from "../olm";
import { handleHealthcheckStatusMessage } from "../target"; import { handleHealthcheckStatusMessage } from "../target";
import { handleRoundTripMessage } from "./handleRoundTripMessage";
import { MessageHandler } from "./types"; import { MessageHandler } from "./types";
export const messageHandlers: Record<string, MessageHandler> = { export const messageHandlers: Record<string, MessageHandler> = {
@@ -35,7 +36,8 @@ export const messageHandlers: Record<string, MessageHandler> = {
"newt/socket/containers": handleDockerContainersMessage, "newt/socket/containers": handleDockerContainersMessage,
"newt/ping/request": handleNewtPingRequestMessage, "newt/ping/request": handleNewtPingRequestMessage,
"newt/blueprint/apply": handleApplyBlueprintMessage, "newt/blueprint/apply": handleApplyBlueprintMessage,
"newt/healthcheck/status": handleHealthcheckStatusMessage "newt/healthcheck/status": handleHealthcheckStatusMessage,
"ws/round-trip/complete": handleRoundTripMessage
}; };
startOlmOfflineChecker(); // this is to handle the offline check for olms startOlmOfflineChecker(); // this is to handle the offline check for olms

View File

@@ -0,0 +1,5 @@
export type Pagination = { total: number; pageSize: number; page: number };
export type PaginatedResponse<T> = T & {
pagination: Pagination;
};

View File

@@ -6,6 +6,7 @@ import { redirect } from "next/navigation";
import { getTranslations } from "next-intl/server"; import { getTranslations } from "next-intl/server";
import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser"; import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser";
import { getCachedOrg } from "@app/lib/api/getCachedOrg"; import { getCachedOrg } from "@app/lib/api/getCachedOrg";
import { build } from "@server/build";
type BillingSettingsProps = { type BillingSettingsProps = {
children: React.ReactNode; children: React.ReactNode;
@@ -17,6 +18,9 @@ export default async function BillingSettingsPage({
params params
}: BillingSettingsProps) { }: BillingSettingsProps) {
const { orgId } = await params; const { orgId } = await params;
if (build !== "saas") {
redirect(`/${orgId}/settings`);
}
const user = await verifySession(); const user = await verifySession();
@@ -40,6 +44,10 @@ export default async function BillingSettingsPage({
redirect(`/${orgId}`); redirect(`/${orgId}`);
} }
if (!(org?.org?.isBillingOrg && orgUser?.isOwner)) {
redirect(`/${orgId}`);
}
const t = await getTranslations(); const t = await getTranslations();
return ( return (

View File

@@ -110,37 +110,42 @@ const planOptions: PlanOption[] = [
// Tier limits mapping derived from limit sets // Tier limits mapping derived from limit sets
const tierLimits: Record< const tierLimits: Record<
Tier | "basic", Tier | "basic",
{ users: number; sites: number; domains: number; remoteNodes: number } { users: number; sites: number; domains: number; remoteNodes: number; organizations: number }
> = { > = {
basic: { basic: {
users: freeLimitSet[FeatureId.USERS]?.value ?? 0, users: freeLimitSet[FeatureId.USERS]?.value ?? 0,
sites: freeLimitSet[FeatureId.SITES]?.value ?? 0, sites: freeLimitSet[FeatureId.SITES]?.value ?? 0,
domains: freeLimitSet[FeatureId.DOMAINS]?.value ?? 0, domains: freeLimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: freeLimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
}, },
tier1: { tier1: {
users: tier1LimitSet[FeatureId.USERS]?.value ?? 0, users: tier1LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0, sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier1LimitSet[FeatureId.DOMAINS]?.value ?? 0, domains: tier1LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier1LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
}, },
tier2: { tier2: {
users: tier2LimitSet[FeatureId.USERS]?.value ?? 0, users: tier2LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier2LimitSet[FeatureId.SITES]?.value ?? 0, sites: tier2LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier2LimitSet[FeatureId.DOMAINS]?.value ?? 0, domains: tier2LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier2LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
}, },
tier3: { tier3: {
users: tier3LimitSet[FeatureId.USERS]?.value ?? 0, users: tier3LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier3LimitSet[FeatureId.SITES]?.value ?? 0, sites: tier3LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier3LimitSet[FeatureId.DOMAINS]?.value ?? 0, domains: tier3LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier3LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
}, },
enterprise: { enterprise: {
users: 0, // Custom for enterprise users: 0, // Custom for enterprise
sites: 0, // Custom for enterprise sites: 0, // Custom for enterprise
domains: 0, // Custom for enterprise domains: 0, // Custom for enterprise
remoteNodes: 0 // Custom for enterprise remoteNodes: 0, // Custom for enterprise
organizations: 0 // Custom for enterprise
} }
}; };
@@ -179,6 +184,7 @@ export default function BillingPage() {
const SITES = "sites"; const SITES = "sites";
const DOMAINS = "domains"; const DOMAINS = "domains";
const REMOTE_EXIT_NODES = "remoteExitNodes"; const REMOTE_EXIT_NODES = "remoteExitNodes";
const ORGINIZATIONS = "organizations";
// Confirmation dialog state // Confirmation dialog state
const [showConfirmDialog, setShowConfirmDialog] = useState(false); const [showConfirmDialog, setShowConfirmDialog] = useState(false);
@@ -619,6 +625,16 @@ export default function BillingPage() {
}); });
} }
// Check organizations
const organizationsUsage = getUsageValue(ORGINIZATIONS);
if (limits.organizations > 0 && organizationsUsage > limits.organizations) {
violations.push({
feature: "Organizations",
currentUsage: organizationsUsage,
newLimit: limits.organizations
});
}
return violations; return violations;
}; };
@@ -752,7 +768,7 @@ export default function BillingPage() {
<div className="text-sm text-muted-foreground mb-3"> <div className="text-sm text-muted-foreground mb-3">
{t("billingMaximumLimits") || "Maximum Limits"} {t("billingMaximumLimits") || "Maximum Limits"}
</div> </div>
<InfoSections cols={4}> <InfoSections cols={5}>
<InfoSection> <InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs"> <InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingUsers") || "Users"} {t("billingUsers") || "Users"}
@@ -855,6 +871,41 @@ export default function BillingPage() {
)} )}
</InfoSectionContent> </InfoSectionContent>
</InfoSection> </InfoSection>
<InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingOrganizations") ||
"Organizations"}
</InfoSectionTitle>
<InfoSectionContent className="text-sm">
{isOverLimit(ORGINIZATIONS) ? (
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "orgs"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(ORGINIZATIONS), limit: getLimitValue(ORGINIZATIONS) ?? 0 }) || `Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}</p>
</TooltipContent>
</Tooltip>
) : (
<>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "orgs"}
</>
)}
</InfoSectionContent>
</InfoSection>
<InfoSection> <InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs"> <InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingRemoteNodes") || {t("billingRemoteNodes") ||
@@ -872,7 +923,7 @@ export default function BillingPage() {
t("billingUnlimited") ?? t("billingUnlimited") ??
"∞"}{" "} "∞"}{" "}
{getLimitValue(REMOTE_EXIT_NODES) !== {getLimitValue(REMOTE_EXIT_NODES) !==
null && "remote nodes"} null && "nodes"}
</span> </span>
</TooltipTrigger> </TooltipTrigger>
<TooltipContent> <TooltipContent>
@@ -885,7 +936,7 @@ export default function BillingPage() {
t("billingUnlimited") ?? t("billingUnlimited") ??
"∞"}{" "} "∞"}{" "}
{getLimitValue(REMOTE_EXIT_NODES) !== {getLimitValue(REMOTE_EXIT_NODES) !==
null && "remote nodes"} null && "nodes"}
</> </>
)} )}
</InfoSectionContent> </InfoSectionContent>
@@ -1016,6 +1067,17 @@ export default function BillingPage() {
"Domains"} "Domains"}
</span> </span>
</div> </div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span>
{
tierLimits[pendingTier.tier]
.organizations
}{" "}
{t("billingOrganizations") ||
"Organizations"}
</span>
</div>
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" /> <Check className="h-4 w-4 text-green-600" />
<span> <span>

View File

@@ -4,6 +4,8 @@ import { redirect } from "next/navigation";
import { cache } from "react"; import { cache } from "react";
import { getTranslations } from "next-intl/server"; import { getTranslations } from "next-intl/server";
import { build } from "@server/build"; import { build } from "@server/build";
import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser";
import { getCachedOrg } from "@app/lib/api/getCachedOrg";
type LicensesSettingsProps = { type LicensesSettingsProps = {
children: React.ReactNode; children: React.ReactNode;
@@ -27,6 +29,26 @@ export default async function LicensesSetingsLayoutProps({
redirect(`/`); redirect(`/`);
} }
let orgUser = null;
try {
const res = await getCachedOrgUser(orgId, user.userId);
orgUser = res.data.data;
} catch {
redirect(`/${orgId}`);
}
let org = null;
try {
const res = await getCachedOrg(orgId);
org = res.data.data;
} catch {
redirect(`/${orgId}`);
}
if (!org?.org?.isBillingOrg || !orgUser?.isOwner) {
redirect(`/${orgId}`);
}
const t = await getTranslations(); const t = await getTranslations();
return ( return (

View File

@@ -7,10 +7,11 @@ import { authCookieHeader } from "@app/lib/api/cookies";
import { ListClientsResponse } from "@server/routers/client"; import { ListClientsResponse } from "@server/routers/client";
import { AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import { getTranslations } from "next-intl/server"; import { getTranslations } from "next-intl/server";
import type { Pagination } from "@server/types/Pagination";
type ClientsPageProps = { type ClientsPageProps = {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
searchParams: Promise<{ view?: string }>; searchParams: Promise<Record<string, string>>;
}; };
export const dynamic = "force-dynamic"; export const dynamic = "force-dynamic";
@@ -19,17 +20,25 @@ export default async function ClientsPage(props: ClientsPageProps) {
const t = await getTranslations(); const t = await getTranslations();
const params = await props.params; const params = await props.params;
const searchParams = new URLSearchParams(await props.searchParams);
let machineClients: ListClientsResponse["clients"] = []; let machineClients: ListClientsResponse["clients"] = [];
let pagination: Pagination = {
page: 1,
total: 0,
pageSize: 20
};
try { try {
const machineRes = await internal.get< const machineRes = await internal.get<
AxiosResponse<ListClientsResponse> AxiosResponse<ListClientsResponse>
>( >(
`/org/${params.orgId}/clients?filter=machine`, `/org/${params.orgId}/clients?${searchParams.toString()}`,
await authCookieHeader() await authCookieHeader()
); );
machineClients = machineRes.data.data.clients; const responseData = machineRes.data.data;
machineClients = responseData.clients;
pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
function formatSize(mb: number): string { function formatSize(mb: number): string {
@@ -80,6 +89,11 @@ export default async function ClientsPage(props: ClientsPageProps) {
<MachineClientsTable <MachineClientsTable
machineClients={machineClientRows} machineClients={machineClientRows}
orgId={params.orgId} orgId={params.orgId}
rowCount={pagination.total}
pagination={{
pageIndex: pagination.page - 1,
pageSize: pagination.pageSize
}}
/> />
</> </>
); );

View File

@@ -602,7 +602,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.biometricsEnabled .biometricsEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -622,7 +623,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.diskEncrypted .diskEncrypted ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -642,7 +644,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.firewallEnabled .firewallEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -663,7 +666,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.autoUpdatesEnabled .autoUpdatesEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -683,7 +687,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.tpmAvailable .tpmAvailable ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -707,7 +712,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.windowsAntivirusEnabled .windowsAntivirusEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -727,7 +733,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.macosSipEnabled .macosSipEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -751,7 +758,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.macosGatekeeperEnabled .macosGatekeeperEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -775,7 +783,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.macosFirewallStealthMode .macosFirewallStealthMode ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -796,7 +805,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.linuxAppArmorEnabled .linuxAppArmorEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>
@@ -817,7 +827,8 @@ export default function GeneralPage() {
) )
? formatPostureValue( ? formatPostureValue(
client.posture client.posture
.linuxSELinuxEnabled .linuxSELinuxEnabled ===
true
) )
: "-"} : "-"}
</InfoSectionContent> </InfoSectionContent>

View File

@@ -1,14 +1,16 @@
import { internal } from "@app/lib/api";
import { authCookieHeader } from "@app/lib/api/cookies";
import { AxiosResponse } from "axios";
import SettingsSectionTitle from "@app/components/SettingsSectionTitle"; import SettingsSectionTitle from "@app/components/SettingsSectionTitle";
import { ListClientsResponse } from "@server/routers/client";
import { getTranslations } from "next-intl/server";
import type { ClientRow } from "@app/components/UserDevicesTable"; import type { ClientRow } from "@app/components/UserDevicesTable";
import UserDevicesTable from "@app/components/UserDevicesTable"; import UserDevicesTable from "@app/components/UserDevicesTable";
import { internal } from "@app/lib/api";
import { authCookieHeader } from "@app/lib/api/cookies";
import { type ListUserDevicesResponse } from "@server/routers/client";
import type { Pagination } from "@server/types/Pagination";
import { AxiosResponse } from "axios";
import { getTranslations } from "next-intl/server";
type ClientsPageProps = { type ClientsPageProps = {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
searchParams: Promise<Record<string, string>>;
}; };
export const dynamic = "force-dynamic"; export const dynamic = "force-dynamic";
@@ -17,15 +19,26 @@ export default async function ClientsPage(props: ClientsPageProps) {
const t = await getTranslations(); const t = await getTranslations();
const params = await props.params; const params = await props.params;
const searchParams = new URLSearchParams(await props.searchParams);
let userClients: ListClientsResponse["clients"] = []; let userClients: ListUserDevicesResponse["devices"] = [];
let pagination: Pagination = {
page: 1,
total: 0,
pageSize: 20
};
try { try {
const userRes = await internal.get<AxiosResponse<ListClientsResponse>>( const userRes = await internal.get<
`/org/${params.orgId}/clients?filter=user`, AxiosResponse<ListUserDevicesResponse>
>(
`/org/${params.orgId}/user-devices?${searchParams.toString()}`,
await authCookieHeader() await authCookieHeader()
); );
userClients = userRes.data.data.clients; const responseData = userRes.data.data;
userClients = responseData.devices;
pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
function formatSize(mb: number): string { function formatSize(mb: number): string {
@@ -39,31 +52,29 @@ export default async function ClientsPage(props: ClientsPageProps) {
} }
const mapClientToRow = ( const mapClientToRow = (
client: ListClientsResponse["clients"][0] client: ListUserDevicesResponse["devices"][number]
): ClientRow => { ): ClientRow => {
// Build fingerprint object if any fingerprint data exists // Build fingerprint object if any fingerprint data exists
const hasFingerprintData = const hasFingerprintData =
(client as any).fingerprintPlatform || client.fingerprintPlatform ||
(client as any).fingerprintOsVersion || client.fingerprintOsVersion ||
(client as any).fingerprintKernelVersion || client.fingerprintKernelVersion ||
(client as any).fingerprintArch || client.fingerprintArch ||
(client as any).fingerprintSerialNumber || client.fingerprintSerialNumber ||
(client as any).fingerprintUsername || client.fingerprintUsername ||
(client as any).fingerprintHostname || client.fingerprintHostname ||
(client as any).deviceModel; client.deviceModel;
const fingerprint = hasFingerprintData const fingerprint = hasFingerprintData
? { ? {
platform: (client as any).fingerprintPlatform || null, platform: client.fingerprintPlatform,
osVersion: (client as any).fingerprintOsVersion || null, osVersion: client.fingerprintOsVersion,
kernelVersion: kernelVersion: client.fingerprintKernelVersion,
(client as any).fingerprintKernelVersion || null, arch: client.fingerprintArch,
arch: (client as any).fingerprintArch || null, deviceModel: client.deviceModel,
deviceModel: (client as any).deviceModel || null, serialNumber: client.fingerprintSerialNumber,
serialNumber: username: client.fingerprintUsername,
(client as any).fingerprintSerialNumber || null, hostname: client.fingerprintHostname
username: (client as any).fingerprintUsername || null,
hostname: (client as any).fingerprintHostname || null
} }
: null; : null;
@@ -71,19 +82,19 @@ export default async function ClientsPage(props: ClientsPageProps) {
name: client.name, name: client.name,
id: client.clientId, id: client.clientId,
subnet: client.subnet.split("/")[0], subnet: client.subnet.split("/")[0],
mbIn: formatSize(client.megabytesIn || 0), mbIn: formatSize(client.megabytesIn ?? 0),
mbOut: formatSize(client.megabytesOut || 0), mbOut: formatSize(client.megabytesOut ?? 0),
orgId: params.orgId, orgId: params.orgId,
online: client.online, online: client.online,
olmVersion: client.olmVersion || undefined, olmVersion: client.olmVersion || undefined,
olmUpdateAvailable: client.olmUpdateAvailable || false, olmUpdateAvailable: Boolean(client.olmUpdateAvailable),
userId: client.userId, userId: client.userId,
username: client.username, username: client.username,
userEmail: client.userEmail, userEmail: client.userEmail,
niceId: client.niceId, niceId: client.niceId,
agent: client.agent, agent: client.agent,
archived: client.archived || false, archived: Boolean(client.archived),
blocked: client.blocked || false, blocked: Boolean(client.blocked),
approvalState: client.approvalState, approvalState: client.approvalState,
fingerprint fingerprint
}; };
@@ -101,6 +112,11 @@ export default async function ClientsPage(props: ClientsPageProps) {
<UserDevicesTable <UserDevicesTable
userClients={userClientRows} userClients={userClientRows}
orgId={params.orgId} orgId={params.orgId}
rowCount={pagination.total}
pagination={{
pageIndex: pagination.page - 1,
pageSize: pagination.pageSize
}}
/> />
</> </>
); );

View File

@@ -3,11 +3,7 @@ import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog";
import { Button } from "@app/components/ui/button"; import { Button } from "@app/components/ui/button";
import { useOrgContext } from "@app/hooks/useOrgContext"; import { useOrgContext } from "@app/hooks/useOrgContext";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { import { useState, useTransition, useActionState } from "react";
useState,
useTransition,
useActionState
} from "react";
import { import {
Form, Form,
FormControl, FormControl,
@@ -54,7 +50,7 @@ export default function GeneralPage() {
return ( return (
<SettingsContainer> <SettingsContainer>
<GeneralSectionForm org={org.org} /> <GeneralSectionForm org={org.org} />
{build !== "saas" && <DeleteForm org={org.org} />} {!org.org.isBillingOrg && <DeleteForm org={org.org} />}
</SettingsContainer> </SettingsContainer>
); );
} }

View File

@@ -77,12 +77,16 @@ export default async function SettingsLayout(props: SettingsLayoutProps) {
} }
} catch (e) {} } catch (e) {}
const primaryOrg = orgs.find((o) => o.orgId === params.orgId)?.isPrimaryOrg;
return ( return (
<UserProvider user={user}> <UserProvider user={user}>
<Layout <Layout
orgId={params.orgId} orgId={params.orgId}
orgs={orgs} orgs={orgs}
navItems={orgNavSections(env)} navItems={orgNavSections(env, {
isPrimaryOrg: primaryOrg
})}
> >
{children} {children}
</Layout> </Layout>

View File

@@ -14,7 +14,7 @@ import { redirect } from "next/navigation";
export interface ClientResourcesPageProps { export interface ClientResourcesPageProps {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
searchParams: Promise<{ view?: string }>; searchParams: Promise<Record<string, string>>;
} }
export default async function ClientResourcesPage( export default async function ClientResourcesPage(
@@ -22,22 +22,24 @@ export default async function ClientResourcesPage(
) { ) {
const params = await props.params; const params = await props.params;
const t = await getTranslations(); const t = await getTranslations();
const searchParams = new URLSearchParams(await props.searchParams);
let resources: ListResourcesResponse["resources"] = [];
try {
const res = await internal.get<AxiosResponse<ListResourcesResponse>>(
`/org/${params.orgId}/resources`,
await authCookieHeader()
);
resources = res.data.data.resources;
} catch (e) {}
let siteResources: ListAllSiteResourcesByOrgResponse["siteResources"] = []; let siteResources: ListAllSiteResourcesByOrgResponse["siteResources"] = [];
let pagination: ListResourcesResponse["pagination"] = {
total: 0,
page: 1,
pageSize: 20
};
try { try {
const res = await internal.get< const res = await internal.get<
AxiosResponse<ListAllSiteResourcesByOrgResponse> AxiosResponse<ListAllSiteResourcesByOrgResponse>
>(`/org/${params.orgId}/site-resources`, await authCookieHeader()); >(
siteResources = res.data.data.siteResources; `/org/${params.orgId}/site-resources?${searchParams.toString()}`,
await authCookieHeader()
);
const responseData = res.data.data;
siteResources = responseData.siteResources;
pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
let org = null; let org = null;
@@ -89,9 +91,10 @@ export default async function ClientResourcesPage(
<ClientResourcesTable <ClientResourcesTable
internalResources={internalResourceRows} internalResources={internalResourceRows}
orgId={params.orgId} orgId={params.orgId}
defaultSort={{ rowCount={pagination.total}
id: "name", pagination={{
desc: false pageIndex: pagination.page - 1,
pageSize: pagination.pageSize
}} }}
/> />
</OrgProvider> </OrgProvider>

View File

@@ -16,7 +16,7 @@ import { cache } from "react";
export interface ProxyResourcesPageProps { export interface ProxyResourcesPageProps {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
searchParams: Promise<{ view?: string }>; searchParams: Promise<Record<string, string>>;
} }
export default async function ProxyResourcesPage( export default async function ProxyResourcesPage(
@@ -24,14 +24,22 @@ export default async function ProxyResourcesPage(
) { ) {
const params = await props.params; const params = await props.params;
const t = await getTranslations(); const t = await getTranslations();
const searchParams = new URLSearchParams(await props.searchParams);
let resources: ListResourcesResponse["resources"] = []; let resources: ListResourcesResponse["resources"] = [];
let pagination: ListResourcesResponse["pagination"] = {
total: 0,
page: 1,
pageSize: 20
};
try { try {
const res = await internal.get<AxiosResponse<ListResourcesResponse>>( const res = await internal.get<AxiosResponse<ListResourcesResponse>>(
`/org/${params.orgId}/resources`, `/org/${params.orgId}/resources?${searchParams.toString()}`,
await authCookieHeader() await authCookieHeader()
); );
resources = res.data.data.resources; const responseData = res.data.data;
resources = responseData.resources;
pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
let siteResources: ListAllSiteResourcesByOrgResponse["siteResources"] = []; let siteResources: ListAllSiteResourcesByOrgResponse["siteResources"] = [];
@@ -104,9 +112,10 @@ export default async function ProxyResourcesPage(
<ProxyResourcesTable <ProxyResourcesTable
resources={resourceRows} resources={resourceRows}
orgId={params.orgId} orgId={params.orgId}
defaultSort={{ rowCount={pagination.total}
id: "name", pagination={{
desc: false pageIndex: pagination.page - 1,
pageSize: pagination.pageSize
}} }}
/> />
</OrgProvider> </OrgProvider>

View File

@@ -63,7 +63,6 @@ import { QRCodeCanvas } from "qrcode.react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { build } from "@server/build"; import { build } from "@server/build";
import { NewtSiteInstallCommands } from "@app/components/newt-install-commands"; import { NewtSiteInstallCommands } from "@app/components/newt-install-commands";
import { id } from "date-fns/locale";
type SiteType = "newt" | "wireguard" | "local"; type SiteType = "newt" | "wireguard" | "local";

View File

@@ -9,19 +9,30 @@ import { getTranslations } from "next-intl/server";
type SitesPageProps = { type SitesPageProps = {
params: Promise<{ orgId: string }>; params: Promise<{ orgId: string }>;
searchParams: Promise<Record<string, string>>;
}; };
export const dynamic = "force-dynamic"; export const dynamic = "force-dynamic";
export default async function SitesPage(props: SitesPageProps) { export default async function SitesPage(props: SitesPageProps) {
const params = await props.params; const params = await props.params;
const searchParams = new URLSearchParams(await props.searchParams);
let sites: ListSitesResponse["sites"] = []; let sites: ListSitesResponse["sites"] = [];
let pagination: ListSitesResponse["pagination"] = {
total: 0,
page: 1,
pageSize: 20
};
try { try {
const res = await internal.get<AxiosResponse<ListSitesResponse>>( const res = await internal.get<AxiosResponse<ListSitesResponse>>(
`/org/${params.orgId}/sites`, `/org/${params.orgId}/sites?${searchParams.toString()}`,
await authCookieHeader() await authCookieHeader()
); );
sites = res.data.data.sites; const responseData = res.data.data;
sites = responseData.sites;
pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
const t = await getTranslations(); const t = await getTranslations();
@@ -60,8 +71,6 @@ export default async function SitesPage(props: SitesPageProps) {
return ( return (
<> <>
{/* <SitesSplashCard /> */}
<SettingsSectionTitle <SettingsSectionTitle
title={t("siteManageSites")} title={t("siteManageSites")}
description={t("siteDescription")} description={t("siteDescription")}
@@ -69,7 +78,15 @@ export default async function SitesPage(props: SitesPageProps) {
<SitesBanner /> <SitesBanner />
<SitesTable sites={siteRows} orgId={params.orgId} /> <SitesTable
sites={siteRows}
orgId={params.orgId}
rowCount={pagination.total}
pagination={{
pageIndex: pagination.page - 1,
pageSize: pagination.pageSize
}}
/>
</> </>
); );
} }

View File

@@ -15,6 +15,7 @@ export default async function Page(props: {
redirect: string | undefined; redirect: string | undefined;
email: string | undefined; email: string | undefined;
fromSmartLogin: string | undefined; fromSmartLogin: string | undefined;
skipVerificationEmail: string | undefined;
}>; }>;
}) { }) {
const searchParams = await props.searchParams; const searchParams = await props.searchParams;
@@ -75,6 +76,10 @@ export default async function Page(props: {
inviteId={inviteId} inviteId={inviteId}
emailParam={searchParams.email} emailParam={searchParams.email}
fromSmartLogin={searchParams.fromSmartLogin === "true"} fromSmartLogin={searchParams.fromSmartLogin === "true"}
skipVerificationEmail={
searchParams.skipVerificationEmail === "true" ||
searchParams.skipVerificationEmail === "1"
}
/> />
<p className="text-center text-muted-foreground mt-4"> <p className="text-center text-muted-foreground mt-4">

View File

@@ -31,6 +31,10 @@ export type SidebarNavSection = {
items: SidebarNavItem[]; items: SidebarNavItem[];
}; };
export type OrgNavSectionsOptions = {
isPrimaryOrg?: boolean;
};
// Merged from 'user-management-and-resources' branch // Merged from 'user-management-and-resources' branch
export const orgLangingNavItems: SidebarNavItem[] = [ export const orgLangingNavItems: SidebarNavItem[] = [
{ {
@@ -40,7 +44,10 @@ export const orgLangingNavItems: SidebarNavItem[] = [
} }
]; ];
export const orgNavSections = (env?: Env): SidebarNavSection[] => [ export const orgNavSections = (
env?: Env,
options?: OrgNavSectionsOptions
): SidebarNavSection[] => [
{ {
heading: "sidebarGeneral", heading: "sidebarGeneral",
items: [ items: [
@@ -214,28 +221,28 @@ export const orgNavSections = (env?: Env): SidebarNavSection[] => [
title: "sidebarSettings", title: "sidebarSettings",
href: "/{orgId}/settings/general", href: "/{orgId}/settings/general",
icon: <Settings className="size-4 flex-none" /> icon: <Settings className="size-4 flex-none" />
}
]
}, },
...(build == "saas" && options?.isPrimaryOrg
...(build == "saas"
? [ ? [
{
heading: "sidebarBillingAndLicenses",
items: [
{ {
title: "sidebarBilling", title: "sidebarBilling",
href: "/{orgId}/settings/billing", href: "/{orgId}/settings/billing",
icon: <CreditCard className="size-4 flex-none" /> icon: <CreditCard className="size-4 flex-none" />
} },
]
: []),
...(build == "saas"
? [
{ {
title: "sidebarEnterpriseLicenses", title: "sidebarEnterpriseLicenses",
href: "/{orgId}/settings/license", href: "/{orgId}/settings/license",
icon: <TicketCheck className="size-4 flex-none" /> icon: <TicketCheck className="size-4 flex-none" />
} }
] ]
: [])
]
} }
]
: [])
]; ];
export const adminNavSections = (env?: Env): SidebarNavSection[] => [ export const adminNavSections = (env?: Env): SidebarNavSection[] => [

View File

@@ -73,7 +73,7 @@ export default async function Page(props: {
if (!orgs.length) { if (!orgs.length) {
if (!env.flags.disableUserCreateOrg || user.serverAdmin) { if (!env.flags.disableUserCreateOrg || user.serverAdmin) {
redirect("/setup"); redirect("/setup?firstOrg");
} }
} }
@@ -86,6 +86,14 @@ export default async function Page(props: {
targetOrgId = lastOrgCookie; targetOrgId = lastOrgCookie;
} else { } else {
let ownedOrg = orgs.find((org) => org.isOwner); let ownedOrg = orgs.find((org) => org.isOwner);
let primaryOrg = orgs.find((org) => org.isPrimaryOrg);
if (!ownedOrg) {
if (primaryOrg) {
ownedOrg = primaryOrg;
} else {
ownedOrg = orgs[0];
}
}
if (!ownedOrg) { if (!ownedOrg) {
ownedOrg = orgs[0]; ownedOrg = orgs[0];
} }

View File

@@ -4,19 +4,14 @@ import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle
} from "@app/components/ui/card";
import { formatAxiosError } from "@app/lib/api"; import { formatAxiosError } from "@app/lib/api";
import { createApiClient } from "@app/lib/api"; import { createApiClient } from "@app/lib/api";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { useUserContext } from "@app/hooks/useUserContext";
import { build } from "@server/build";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { z } from "zod"; import { z } from "zod";
import { useRouter } from "next/navigation"; import { useRouter, useSearchParams } from "next/navigation";
import { useForm } from "react-hook-form"; import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod"; import { zodResolver } from "@hookform/resolvers/zod";
import { import {
@@ -35,7 +30,7 @@ import {
CollapsibleContent, CollapsibleContent,
CollapsibleTrigger CollapsibleTrigger
} from "@app/components/ui/collapsible"; } from "@app/components/ui/collapsible";
import { ChevronsUpDown } from "lucide-react"; import { ArrowRight, ChevronsUpDown } from "lucide-react";
import { cn } from "@app/lib/cn"; import { cn } from "@app/lib/cn";
type Step = "org" | "site" | "resources"; type Step = "org" | "site" | "resources";
@@ -45,6 +40,7 @@ export default function StepperForm() {
const [orgIdTaken, setOrgIdTaken] = useState(false); const [orgIdTaken, setOrgIdTaken] = useState(false);
const t = useTranslations(); const t = useTranslations();
const { env } = useEnvContext(); const { env } = useEnvContext();
const { user } = useUserContext();
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [isChecked, setIsChecked] = useState(false); const [isChecked, setIsChecked] = useState(false);
@@ -54,7 +50,10 @@ export default function StepperForm() {
const orgSchema = z.object({ const orgSchema = z.object({
orgName: z.string().min(1, { message: t("orgNameRequired") }), orgName: z.string().min(1, { message: t("orgNameRequired") }),
orgId: z.string().min(1, { message: t("orgIdRequired") }), orgId: z
.string()
.min(1, { message: t("orgIdRequired") })
.max(32, { message: t("orgIdMaxLength") }),
subnet: z.string().min(1, { message: t("subnetRequired") }), subnet: z.string().min(1, { message: t("subnetRequired") }),
utilitySubnet: z.string().min(1, { message: t("subnetRequired") }) utilitySubnet: z.string().min(1, { message: t("subnetRequired") })
}); });
@@ -71,12 +70,27 @@ export default function StepperForm() {
const api = createApiClient(useEnvContext()); const api = createApiClient(useEnvContext());
const router = useRouter(); const router = useRouter();
const searchParams = useSearchParams();
const isFirstOrg = searchParams.get("firstOrg") != null;
// Fetch default subnet on component mount // Fetch default subnet on component mount
useEffect(() => { useEffect(() => {
fetchDefaultSubnet(); fetchDefaultSubnet();
}, []); }, []);
// Prefill org name and id when build is saas and firstOrg query param is set
useEffect(() => {
if (build !== "saas" || !user || !isFirstOrg) return;
const orgName = user.email
? `${user.email}'s Organization`
: "My Organization";
const orgId = `org_${user.userId}`;
orgForm.setValue("orgName", orgName);
orgForm.setValue("orgId", orgId);
debouncedCheckOrgIdAvailability(orgId);
}, []);
const fetchDefaultSubnet = async () => { const fetchDefaultSubnet = async () => {
try { try {
const res = await api.get(`/pick-org-defaults`); const res = await api.get(`/pick-org-defaults`);
@@ -129,6 +143,16 @@ export default function StepperForm() {
.replace(/^-+|-+$/g, ""); .replace(/^-+|-+$/g, "");
}; };
const sanitizeOrgId = (value: string) => {
return value
.toLowerCase()
.replace(/\s+/g, "-")
.replace(/[^a-z0-9_-]/g, "")
.replace(/-+/g, "-")
.replace(/^-+|-+$/g, "")
.slice(0, 32);
};
async function orgSubmit(values: z.infer<typeof orgSchema>) { async function orgSubmit(values: z.infer<typeof orgSchema>) {
if (orgIdTaken) { if (orgIdTaken) {
return; return;
@@ -161,14 +185,15 @@ export default function StepperForm() {
} }
return ( return (
<>
<Card>
<CardHeader>
<CardTitle>{t("setupNewOrg")}</CardTitle>
<CardDescription>{t("setupCreate")}</CardDescription>
</CardHeader>
<CardContent>
<section className="space-y-6"> <section className="space-y-6">
<div>
<h1 className="text-2xl font-semibold tracking-tight">
{t("setupNewOrg")}
</h1>
<p className="text-muted-foreground text-sm mt-1">
{t("setupCreate")}
</p>
</div>
<div className="flex justify-between mb-2"> <div className="flex justify-between mb-2">
<div className="flex flex-col items-center"> <div className="flex flex-col items-center">
<div <div
@@ -245,9 +270,7 @@ export default function StepperForm() {
name="orgName" name="orgName"
render={({ field }) => ( render={({ field }) => (
<FormItem> <FormItem>
<FormLabel> <FormLabel>{t("setupOrgName")}</FormLabel>
{t("setupOrgName")}
</FormLabel>
<FormControl> <FormControl>
<Input <Input
type="text" type="text"
@@ -260,9 +283,7 @@ export default function StepperForm() {
"-" "-"
); );
const orgId = const orgId =
generateId( generateId(sanitizedValue);
sanitizedValue
);
orgForm.setValue( orgForm.setValue(
"orgId", "orgId",
orgId orgId
@@ -293,20 +314,28 @@ export default function StepperForm() {
name="orgId" name="orgId"
render={({ field }) => ( render={({ field }) => (
<FormItem> <FormItem>
<FormLabel> <FormLabel>{t("orgId")}</FormLabel>
{t("orgId")}
</FormLabel>
<FormControl> <FormControl>
<Input <Input
type="text" type="text"
{...field} {...field}
onChange={(e) => {
const value = sanitizeOrgId(
e.target.value
);
field.onChange(value);
setOrgIdTaken(false);
if (value) {
debouncedCheckOrgIdAvailability(
value
);
}
}}
/> />
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
<FormDescription> <FormDescription>
{t( {t("setupIdentifierMessage")}
"setupIdentifierMessage"
)}
</FormDescription> </FormDescription>
</FormItem> </FormItem>
)} )}
@@ -344,21 +373,14 @@ export default function StepperForm() {
render={({ field }) => ( render={({ field }) => (
<FormItem> <FormItem>
<FormLabel> <FormLabel>
{t( {t("setupSubnetAdvanced")}
"setupSubnetAdvanced"
)}
</FormLabel> </FormLabel>
<FormControl> <FormControl>
<Input <Input type="text" {...field} />
type="text"
{...field}
/>
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
<FormDescription> <FormDescription>
{t( {t("setupSubnetDescription")}
"setupSubnetDescription"
)}
</FormDescription> </FormDescription>
</FormItem> </FormItem>
)} )}
@@ -370,15 +392,10 @@ export default function StepperForm() {
render={({ field }) => ( render={({ field }) => (
<FormItem> <FormItem>
<FormLabel> <FormLabel>
{t( {t("setupUtilitySubnet")}
"setupUtilitySubnet"
)}
</FormLabel> </FormLabel>
<FormControl> <FormControl>
<Input <Input type="text" {...field} />
type="text"
{...field}
/>
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
<FormDescription> <FormDescription>
@@ -409,15 +426,13 @@ export default function StepperForm() {
disabled={loading || orgIdTaken} disabled={loading || orgIdTaken}
> >
{t("setupCreateOrg")} {t("setupCreateOrg")}
<ArrowRight className="ml-2 h-4 w-4" />
</Button> </Button>
</div> </div>
</form> </form>
</Form> </Form>
)} )}
</section> </section>
</CardContent>
</Card>
</>
); );
} }

View File

@@ -2,16 +2,16 @@
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { createApiClient, formatAxiosError } from "@app/lib/api"; import { createApiClient, formatAxiosError } from "@app/lib/api";
import { getUserDisplayName } from "@app/lib/getUserDisplayName";
import { cn } from "@app/lib/cn"; import { cn } from "@app/lib/cn";
import { formatFingerprintInfo } from "@app/lib/formatDeviceFingerprint"; import { formatFingerprintInfo } from "@app/lib/formatDeviceFingerprint";
import { getUserDisplayName } from "@app/lib/getUserDisplayName";
import { import {
approvalFiltersSchema, approvalFiltersSchema,
approvalQueries, approvalQueries,
type ApprovalItem type ApprovalItem
} from "@app/lib/queries"; } from "@app/lib/queries";
import { useQuery } from "@tanstack/react-query"; import { useInfiniteQuery } from "@tanstack/react-query";
import { ArrowRight, Ban, Check, LaptopMinimal, RefreshCw } from "lucide-react"; import { Ban, Check, Loader, RefreshCw } from "lucide-react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import Link from "next/link"; import Link from "next/link";
import { usePathname, useRouter, useSearchParams } from "next/navigation"; import { usePathname, useRouter, useSearchParams } from "next/navigation";
@@ -54,12 +54,20 @@ export function ApprovalFeed({
const { isPaidUser } = usePaidStatus(); const { isPaidUser } = usePaidStatus();
const { data, isFetching, refetch } = useQuery({ const {
data,
isFetching,
isLoading,
refetch,
hasNextPage,
fetchNextPage,
isFetchingNextPage
} = useInfiniteQuery({
...approvalQueries.listApprovals(orgId, filters), ...approvalQueries.listApprovals(orgId, filters),
enabled: isPaidUser(tierMatrix.deviceApprovals) enabled: isPaidUser(tierMatrix.deviceApprovals)
}); });
const approvals = data?.approvals ?? []; const approvals = data?.pages.flatMap((data) => data.approvals) ?? [];
// Show empty state if no approvals are enabled for any role // Show empty state if no approvals are enabled for any role
if (!hasApprovalsEnabled) { if (!hasApprovalsEnabled) {
@@ -115,13 +123,13 @@ export function ApprovalFeed({
onClick={() => { onClick={() => {
refetch(); refetch();
}} }}
disabled={isFetching} disabled={isFetching || isLoading}
className="lg:static gap-2" className="lg:static gap-2"
> >
<RefreshCw <RefreshCw
className={cn( className={cn(
"size-4", "size-4",
isFetching && "animate-spin" (isFetching || isLoading) && "animate-spin"
)} )}
/> />
{t("refresh")} {t("refresh")}
@@ -145,13 +153,30 @@ export function ApprovalFeed({
))} ))}
{approvals.length === 0 && ( {approvals.length === 0 && (
<li className="flex justify-center items-center p-4 text-muted-foreground"> <li className="flex justify-center items-center p-4 text-muted-foreground gap-2">
{t("approvalListEmpty")} {isLoading
? t("loadingApprovals")
: t("approvalListEmpty")}
{isLoading && (
<Loader className="size-4 flex-none animate-spin" />
)}
</li> </li>
)} )}
</ul> </ul>
</CardHeader> </CardHeader>
</Card> </Card>
{hasNextPage && (
<Button
variant="secondary"
className="self-center"
size="lg"
loading={isFetchingNextPage}
onClick={() => fetchNextPage()}
>
{t("approvalLoadMore")}
</Button>
)}
</div> </div>
); );
} }

View File

@@ -1,9 +1,5 @@
"use client"; "use client";
import { zodResolver } from "@hookform/resolvers/zod";
import { startTransition, useActionState, useState } from "react";
import { useForm } from "react-hook-form";
import z from "zod";
import { import {
Form, Form,
FormControl, FormControl,
@@ -13,6 +9,11 @@ import {
FormLabel, FormLabel,
FormMessage FormMessage
} from "@app/components/ui/form"; } from "@app/components/ui/form";
import { zodResolver } from "@hookform/resolvers/zod";
import { useTranslations } from "next-intl";
import { useActionState } from "react";
import { useForm } from "react-hook-form";
import z from "zod";
import { import {
SettingsSection, SettingsSection,
SettingsSectionBody, SettingsSectionBody,
@@ -21,19 +22,19 @@ import {
SettingsSectionHeader, SettingsSectionHeader,
SettingsSectionTitle SettingsSectionTitle
} from "./Settings"; } from "./Settings";
import { useTranslations } from "next-intl";
import type { GetLoginPageBrandingResponse } from "@server/routers/loginPage/types";
import { Input } from "./ui/input";
import { ExternalLink, InfoIcon, XIcon } from "lucide-react";
import { Button } from "./ui/button";
import { createApiClient, formatAxiosError } from "@app/lib/api";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { useRouter } from "next/navigation";
import { toast } from "@app/hooks/useToast";
import { usePaidStatus } from "@app/hooks/usePaidStatus"; import { usePaidStatus } from "@app/hooks/usePaidStatus";
import { toast } from "@app/hooks/useToast";
import { createApiClient, formatAxiosError } from "@app/lib/api";
import { build } from "@server/build"; import { build } from "@server/build";
import type { GetLoginPageBrandingResponse } from "@server/routers/loginPage/types";
import { XIcon } from "lucide-react";
import { useRouter } from "next/navigation";
import { PaidFeaturesAlert } from "./PaidFeaturesAlert"; import { PaidFeaturesAlert } from "./PaidFeaturesAlert";
import { Button } from "./ui/button";
import { Input } from "./ui/input";
import { validateLocalPath } from "@app/lib/validateLocalPath";
import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { tierMatrix } from "@server/lib/billing/tierMatrix";
@@ -45,13 +46,36 @@ export type AuthPageCustomizationProps = {
const AuthPageFormSchema = z.object({ const AuthPageFormSchema = z.object({
logoUrl: z.union([ logoUrl: z.union([
z.literal(""), z.literal(""),
z.url("Must be a valid URL").superRefine(async (url, ctx) => { z.string().superRefine(async (urlOrPath, ctx) => {
const parseResult = z.url().safeParse(urlOrPath);
if (!parseResult.success) {
if (build !== "enterprise") {
ctx.addIssue({
code: "custom",
message: "Must be a valid URL"
});
return;
} else {
try { try {
const response = await fetch(url, { validateLocalPath(urlOrPath);
} catch (error) {
ctx.addIssue({
code: "custom",
message:
"Must be either a valid image URL or a valid pathname starting with `/` and not containing query parameters, `..` or `*`"
});
} finally {
return;
}
}
}
try {
const response = await fetch(urlOrPath, {
method: "HEAD" method: "HEAD"
}).catch(() => { }).catch(() => {
// If HEAD fails (CORS or method not allowed), try GET // If HEAD fails (CORS or method not allowed), try GET
return fetch(url, { method: "GET" }); return fetch(urlOrPath, { method: "GET" });
}); });
if (response.status !== 200) { if (response.status !== 200) {
@@ -271,12 +295,25 @@ export default function AuthPageBrandingForm({
render={({ field }) => ( render={({ field }) => (
<FormItem className="md:col-span-3"> <FormItem className="md:col-span-3">
<FormLabel> <FormLabel>
{t("brandingLogoURL")} {build === "enterprise"
? t(
"brandingLogoURLOrPath"
)
: t("brandingLogoURL")}
</FormLabel> </FormLabel>
<FormControl> <FormControl>
<Input {...field} /> <Input {...field} />
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
<FormDescription>
{build === "enterprise"
? t(
"brandingLogoPathDescription"
)
: t(
"brandingLogoURLDescription"
)}
</FormDescription>
</FormItem> </FormItem>
)} )}
/> />

View File

@@ -25,6 +25,11 @@ import CreateInternalResourceDialog from "@app/components/CreateInternalResource
import EditInternalResourceDialog from "@app/components/EditInternalResourceDialog"; import EditInternalResourceDialog from "@app/components/EditInternalResourceDialog";
import { orgQueries } from "@app/lib/queries"; import { orgQueries } from "@app/lib/queries";
import { useQuery } from "@tanstack/react-query"; import { useQuery } from "@tanstack/react-query";
import type { PaginationState } from "@tanstack/react-table";
import { ControlledDataTable } from "./ui/controlled-data-table";
import { useNavigationContext } from "@app/hooks/useNavigationContext";
import { useDebouncedCallback } from "use-debounce";
import { ColumnFilterButton } from "./ColumnFilterButton";
export type InternalResourceRow = { export type InternalResourceRow = {
id: number; id: number;
@@ -51,18 +56,22 @@ export type InternalResourceRow = {
type ClientResourcesTableProps = { type ClientResourcesTableProps = {
internalResources: InternalResourceRow[]; internalResources: InternalResourceRow[];
orgId: string; orgId: string;
defaultSort?: { pagination: PaginationState;
id: string; rowCount: number;
desc: boolean;
};
}; };
export default function ClientResourcesTable({ export default function ClientResourcesTable({
internalResources, internalResources,
orgId, orgId,
defaultSort pagination,
rowCount
}: ClientResourcesTableProps) { }: ClientResourcesTableProps) {
const router = useRouter(); const router = useRouter();
const {
navigate: filter,
isNavigating: isFiltering,
searchParams
} = useNavigationContext();
const t = useTranslations(); const t = useTranslations();
const { env } = useEnvContext(); const { env } = useEnvContext();
@@ -122,19 +131,7 @@ export default function ClientResourcesTable({
accessorKey: "name", accessorKey: "name",
enableHiding: false, enableHiding: false,
friendlyName: t("name"), friendlyName: t("name"),
header: ({ column }) => { header: () => <span className="p-3">{t("name")}</span>
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(column.getIsSorted() === "asc")
}
>
{t("name")}
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
}
}, },
{ {
id: "niceId", id: "niceId",
@@ -180,9 +177,24 @@ export default function ClientResourcesTable({
accessorKey: "mode", accessorKey: "mode",
friendlyName: t("editInternalResourceDialogMode"), friendlyName: t("editInternalResourceDialogMode"),
header: () => ( header: () => (
<span className="p-3"> <ColumnFilterButton
{t("editInternalResourceDialogMode")} options={[
</span> {
value: "host",
label: t("editInternalResourceDialogModeHost")
},
{
value: "cidr",
label: t("editInternalResourceDialogModeCidr")
}
]}
selectedValue={searchParams.get("mode") ?? undefined}
onValueChange={(value) => handleFilterChange("mode", value)}
searchPlaceholder={t("searchPlaceholder")}
emptyMessage={t("emptySearchOptions")}
label={t("editInternalResourceDialogMode")}
className="p-3"
/>
), ),
cell: ({ row }) => { cell: ({ row }) => {
const resourceRow = row.original; const resourceRow = row.original;
@@ -300,6 +312,37 @@ export default function ClientResourcesTable({
} }
]; ];
function handleFilterChange(
column: string,
value: string | undefined | null
) {
searchParams.delete(column);
searchParams.delete("page");
if (value) {
searchParams.set(column, value);
}
filter({
searchParams
});
}
const handlePaginationChange = (newPage: PaginationState) => {
searchParams.set("page", (newPage.pageIndex + 1).toString());
searchParams.set("pageSize", newPage.pageSize.toString());
filter({
searchParams
});
};
const handleSearchChange = useDebouncedCallback((query: string) => {
searchParams.set("query", query);
searchParams.delete("page");
filter({
searchParams
});
}, 300);
return ( return (
<> <>
{selectedInternalResource && ( {selectedInternalResource && (
@@ -327,19 +370,20 @@ export default function ClientResourcesTable({
/> />
)} )}
<DataTable <ControlledDataTable
columns={internalColumns} columns={internalColumns}
data={internalResources} rows={internalResources}
persistPageSize="internal-resources" tableId="internal-resources"
searchPlaceholder={t("resourcesSearch")} searchPlaceholder={t("resourcesSearch")}
searchColumn="name"
onAdd={() => setIsCreateDialogOpen(true)} onAdd={() => setIsCreateDialogOpen(true)}
addButtonText={t("resourceAdd")} addButtonText={t("resourceAdd")}
onSearch={handleSearchChange}
onRefresh={refreshData} onRefresh={refreshData}
isRefreshing={isRefreshing} onPaginationChange={handlePaginationChange}
defaultSort={defaultSort} pagination={pagination}
enableColumnVisibility={true} rowCount={rowCount}
persistColumnVisibility="internal-resources" isRefreshing={isRefreshing || isFiltering}
enableColumnVisibility
columnVisibility={{ columnVisibility={{
niceId: false, niceId: false,
aliasAddress: false aliasAddress: false

View File

@@ -15,6 +15,7 @@ import {
} from "@app/components/ui/command"; } from "@app/components/ui/command";
import { CheckIcon, ChevronDownIcon, Filter } from "lucide-react"; import { CheckIcon, ChevronDownIcon, Filter } from "lucide-react";
import { cn } from "@app/lib/cn"; import { cn } from "@app/lib/cn";
import { Badge } from "./ui/badge";
interface FilterOption { interface FilterOption {
value: string; value: string;
@@ -61,16 +62,19 @@ export function ColumnFilter({
> >
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<Filter className="h-4 w-4" /> <Filter className="h-4 w-4" />
<span className="truncate">
{selectedOption && (
<Badge className="truncate" variant="secondary">
{selectedOption {selectedOption
? selectedOption.label ? selectedOption.label
: placeholder} : placeholder}
</span> </Badge>
)}
</div> </div>
<ChevronDownIcon className="h-4 w-4 shrink-0 opacity-50" /> <ChevronDownIcon className="h-4 w-4 shrink-0 opacity-50" />
</Button> </Button>
</PopoverTrigger> </PopoverTrigger>
<PopoverContent className="p-0 w-[200px]" align="start"> <PopoverContent className="p-0 w-50" align="start">
<Command> <Command>
<CommandInput placeholder={searchPlaceholder} /> <CommandInput placeholder={searchPlaceholder} />
<CommandList> <CommandList>

View File

@@ -0,0 +1,126 @@
import { useState } from "react";
import { Button } from "@app/components/ui/button";
import {
Popover,
PopoverContent,
PopoverTrigger
} from "@app/components/ui/popover";
import {
Command,
CommandEmpty,
CommandGroup,
CommandInput,
CommandItem,
CommandList
} from "@app/components/ui/command";
import { CheckIcon, ChevronDownIcon, Funnel } from "lucide-react";
import { cn } from "@app/lib/cn";
import { Badge } from "./ui/badge";
interface FilterOption {
value: string;
label: string;
}
interface ColumnFilterButtonProps {
options: FilterOption[];
selectedValue?: string;
onValueChange: (value: string | undefined) => void;
placeholder?: string;
searchPlaceholder?: string;
emptyMessage?: string;
className?: string;
label: string;
}
export function ColumnFilterButton({
options,
selectedValue,
onValueChange,
placeholder,
searchPlaceholder = "Search...",
emptyMessage = "No options found",
className,
label
}: ColumnFilterButtonProps) {
const [open, setOpen] = useState(false);
const selectedOption = options.find(
(option) => option.value === selectedValue
);
return (
<Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger asChild>
<Button
variant="ghost"
role="combobox"
aria-expanded={open}
className={cn(
"justify-between text-sm h-8 px-2",
!selectedValue && "text-muted-foreground",
className
)}
>
<div className="flex items-center gap-2">
{label}
<Funnel className="size-4 flex-none" />
{selectedOption && (
<Badge className="truncate" variant="secondary">
{selectedOption.label}
</Badge>
)}
</div>
</Button>
</PopoverTrigger>
<PopoverContent className="p-0 w-50" align="start">
<Command>
<CommandInput placeholder={searchPlaceholder} />
<CommandList>
<CommandEmpty>{emptyMessage}</CommandEmpty>
<CommandGroup>
{/* Clear filter option */}
{selectedValue && (
<CommandItem
onSelect={() => {
onValueChange(undefined);
setOpen(false);
}}
className="text-muted-foreground"
>
Clear filter
</CommandItem>
)}
{options.map((option) => (
<CommandItem
key={option.value}
value={option.label}
onSelect={() => {
onValueChange(
selectedValue === option.value
? undefined
: option.value
);
setOpen(false);
}}
>
<CheckIcon
className={cn(
"mr-2 h-4 w-4",
selectedValue === option.value
? "opacity-100"
: "opacity-0"
)}
/>
{option.label}
</CommandItem>
))}
</CommandGroup>
</CommandList>
</Command>
</PopoverContent>
</Popover>
);
}

View File

@@ -255,10 +255,7 @@ export default function CreateInternalResourceDialog({
const { data: usersResponse = [] } = useQuery(orgQueries.users({ orgId })); const { data: usersResponse = [] } = useQuery(orgQueries.users({ orgId }));
const { data: clientsResponse = [] } = useQuery( const { data: clientsResponse = [] } = useQuery(
orgQueries.clients({ orgQueries.clients({
orgId, orgId
filters: {
filter: "machine"
}
}) })
); );

View File

@@ -277,10 +277,7 @@ export default function EditInternalResourceDialog({
orgQueries.roles({ orgId }), orgQueries.roles({ orgId }),
orgQueries.users({ orgId }), orgQueries.users({ orgId }),
orgQueries.clients({ orgQueries.clients({
orgId, orgId
filters: {
filter: "machine"
}
}), }),
resourceQueries.siteResourceUsers({ siteResourceId: resource.id }), resourceQueries.siteResourceUsers({ siteResourceId: resource.id }),
resourceQueries.siteResourceRoles({ siteResourceId: resource.id }), resourceQueries.siteResourceRoles({ siteResourceId: resource.id }),

View File

@@ -189,10 +189,12 @@ export function LayoutSidebar({
<div className="w-full border-t border-border" /> <div className="w-full border-t border-border" />
<div className="p-4 pt-1 flex flex-col shrink-0"> <div className="p-4 pt-1 flex flex-col shrink-0">
{canShowProductUpdates && ( {canShowProductUpdates ? (
<div className="mb-3"> <div className="mb-3">
<ProductUpdates isCollapsed={isSidebarCollapsed} /> <ProductUpdates isCollapsed={isSidebarCollapsed} />
</div> </div>
) : (
<div className="mb-3"></div>
)} )}
{build === "enterprise" && ( {build === "enterprise" && (

View File

@@ -16,13 +16,23 @@ import {
ArrowRight, ArrowRight,
ArrowUpDown, ArrowUpDown,
MoreHorizontal, MoreHorizontal,
CircleSlash CircleSlash,
ArrowDown01Icon,
ArrowUp10Icon,
ChevronsUpDownIcon
} from "lucide-react"; } from "lucide-react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import Link from "next/link"; import Link from "next/link";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useMemo, useState, useTransition } from "react"; import { useMemo, useState, useTransition } from "react";
import { Badge } from "./ui/badge"; import { Badge } from "./ui/badge";
import type { PaginationState } from "@tanstack/react-table";
import { ControlledDataTable } from "./ui/controlled-data-table";
import { useNavigationContext } from "@app/hooks/useNavigationContext";
import { useDebouncedCallback } from "use-debounce";
import z from "zod";
import { getNextSortOrder, getSortDirection } from "@app/lib/sortColumn";
import { ColumnFilterButton } from "./ColumnFilterButton";
export type ClientRow = { export type ClientRow = {
id: number; id: number;
@@ -48,14 +58,24 @@ export type ClientRow = {
type ClientTableProps = { type ClientTableProps = {
machineClients: ClientRow[]; machineClients: ClientRow[];
orgId: string; orgId: string;
pagination: PaginationState;
rowCount: number;
}; };
export default function MachineClientsTable({ export default function MachineClientsTable({
machineClients, machineClients,
orgId orgId,
pagination,
rowCount
}: ClientTableProps) { }: ClientTableProps) {
const router = useRouter(); const router = useRouter();
const {
navigate: filter,
isNavigating: isFiltering,
searchParams
} = useNavigationContext();
const t = useTranslations(); const t = useTranslations();
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
@@ -65,6 +85,7 @@ export default function MachineClientsTable({
const api = createApiClient(useEnvContext()); const api = createApiClient(useEnvContext());
const [isRefreshing, startTransition] = useTransition(); const [isRefreshing, startTransition] = useTransition();
const [isNavigatingToAddPage, startNavigation] = useTransition();
const defaultMachineColumnVisibility = { const defaultMachineColumnVisibility = {
subnet: false, subnet: false,
@@ -182,22 +203,8 @@ export default function MachineClientsTable({
{ {
accessorKey: "name", accessorKey: "name",
enableHiding: false, enableHiding: false,
friendlyName: "Name", friendlyName: t("name"),
header: ({ column }) => { header: () => <span className="px-3">{t("name")}</span>,
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(
column.getIsSorted() === "asc"
)
}
>
Name
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
},
cell: ({ row }) => { cell: ({ row }) => {
const r = row.original; const r = row.original;
return ( return (
@@ -224,38 +231,35 @@ export default function MachineClientsTable({
{ {
accessorKey: "niceId", accessorKey: "niceId",
friendlyName: "Identifier", friendlyName: "Identifier",
header: ({ column }) => { header: () => <span className="px-3">{t("identifier")}</span>
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(
column.getIsSorted() === "asc"
)
}
>
{t("identifier")}
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
}
}, },
{ {
accessorKey: "online", accessorKey: "online",
friendlyName: "Connectivity", friendlyName: t("online"),
header: ({ column }) => { header: () => {
return ( return (
<Button <ColumnFilterButton
variant="ghost" options={[
onClick={() => {
column.toggleSorting( value: "true",
column.getIsSorted() === "asc" label: t("connected")
) },
{
value: "false",
label: t("disconnected")
} }
> ]}
Connectivity selectedValue={
<ArrowUpDown className="ml-2 h-4 w-4" /> searchParams.get("online") ?? undefined
</Button> }
onValueChange={(value) =>
handleFilterChange("online", value)
}
searchPlaceholder={t("searchPlaceholder")}
emptyMessage={t("emptySearchOptions")}
label={t("online")}
className="p-3"
/>
); );
}, },
cell: ({ row }) => { cell: ({ row }) => {
@@ -279,38 +283,52 @@ export default function MachineClientsTable({
}, },
{ {
accessorKey: "mbIn", accessorKey: "mbIn",
friendlyName: "Data In", friendlyName: t("dataIn"),
header: ({ column }) => { header: () => {
const dataInOrder = getSortDirection(
"megabytesIn",
searchParams
);
const Icon =
dataInOrder === "asc"
? ArrowDown01Icon
: dataInOrder === "desc"
? ArrowUp10Icon
: ChevronsUpDownIcon;
return ( return (
<Button <Button
variant="ghost" variant="ghost"
onClick={() => onClick={() => toggleSort("megabytesIn")}
column.toggleSorting(
column.getIsSorted() === "asc"
)
}
> >
Data In {t("dataIn")}
<ArrowUpDown className="ml-2 h-4 w-4" /> <Icon className="ml-2 h-4 w-4" />
</Button> </Button>
); );
} }
}, },
{ {
accessorKey: "mbOut", accessorKey: "mbOut",
friendlyName: "Data Out", friendlyName: t("dataOut"),
header: ({ column }) => { header: () => {
const dataOutOrder = getSortDirection(
"megabytesOut",
searchParams
);
const Icon =
dataOutOrder === "asc"
? ArrowDown01Icon
: dataOutOrder === "desc"
? ArrowUp10Icon
: ChevronsUpDownIcon;
return ( return (
<Button <Button
variant="ghost" variant="ghost"
onClick={() => onClick={() => toggleSort("megabytesOut")}
column.toggleSorting(
column.getIsSorted() === "asc"
)
}
> >
Data Out {t("dataOut")}
<ArrowUpDown className="ml-2 h-4 w-4" /> <Icon className="ml-2 h-4 w-4" />
</Button> </Button>
); );
} }
@@ -318,21 +336,7 @@ export default function MachineClientsTable({
{ {
accessorKey: "client", accessorKey: "client",
friendlyName: t("agent"), friendlyName: t("agent"),
header: ({ column }) => { header: () => <span className="px-3">{t("agent")}</span>,
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(
column.getIsSorted() === "asc"
)
}
>
{t("agent")}
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
},
cell: ({ row }) => { cell: ({ row }) => {
const originalRow = row.original; const originalRow = row.original;
@@ -356,22 +360,8 @@ export default function MachineClientsTable({
}, },
{ {
accessorKey: "subnet", accessorKey: "subnet",
friendlyName: "Address", friendlyName: t("address"),
header: ({ column }) => { header: () => <span className="px-3">{t("address")}</span>
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(
column.getIsSorted() === "asc"
)
}
>
Address
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
}
} }
]; ];
@@ -455,7 +445,56 @@ export default function MachineClientsTable({
} }
return baseColumns; return baseColumns;
}, [hasRowsWithoutUserId, t]); }, [hasRowsWithoutUserId, t, getSortDirection, toggleSort]);
const booleanSearchFilterSchema = z
.enum(["true", "false"])
.optional()
.catch(undefined);
function handleFilterChange(
column: string,
value: string | null | undefined | string[]
) {
searchParams.delete(column);
searchParams.delete("page");
if (typeof value === "string") {
searchParams.set(column, value);
} else if (value) {
for (const val of value) {
searchParams.append(column, val);
}
}
filter({
searchParams
});
}
function toggleSort(column: string) {
const newSearch = getNextSortOrder(column, searchParams);
filter({
searchParams: newSearch
});
}
const handlePaginationChange = (newPage: PaginationState) => {
searchParams.set("page", (newPage.pageIndex + 1).toString());
searchParams.set("pageSize", newPage.pageSize.toString());
filter({
searchParams
});
};
const handleSearchChange = useDebouncedCallback((query: string) => {
searchParams.set("query", query);
searchParams.delete("page");
filter({
searchParams
});
}, 300);
return ( return (
<> <>
@@ -478,20 +517,25 @@ export default function MachineClientsTable({
title="Delete Client" title="Delete Client"
/> />
)} )}
<DataTable <ControlledDataTable
columns={columns} columns={columns}
data={machineClients || []} rows={machineClients}
persistPageSize="machine-clients" tableId="machine-clients"
searchPlaceholder={t("resourcesSearch")} searchPlaceholder={t("resourcesSearch")}
searchColumn="name"
onAdd={() => onAdd={() =>
startNavigation(() =>
router.push(`/${orgId}/settings/clients/machine/create`) router.push(`/${orgId}/settings/clients/machine/create`)
)
} }
pagination={pagination}
rowCount={rowCount}
addButtonText={t("createClient")} addButtonText={t("createClient")}
onRefresh={refreshData} onRefresh={refreshData}
isRefreshing={isRefreshing} isRefreshing={isRefreshing || isFiltering}
enableColumnVisibility={true} onSearch={handleSearchChange}
persistColumnVisibility="machine-clients" onPaginationChange={handlePaginationChange}
isNavigatingToAddPage={isNavigatingToAddPage}
enableColumnVisibility
columnVisibility={defaultMachineColumnVisibility} columnVisibility={defaultMachineColumnVisibility}
stickyLeftColumn="name" stickyLeftColumn="name"
stickyRightColumn="actions" stickyRightColumn="actions"
@@ -518,30 +562,10 @@ export default function MachineClientsTable({
value: "blocked" value: "blocked"
} }
], ],
filterFn: ( onValueChange(selectedValues: string[]) {
row: ClientRow, handleFilterChange("status", selectedValues);
selectedValues: (string | number | boolean)[]
) => {
if (selectedValues.length === 0) return true;
const rowArchived = row.archived || false;
const rowBlocked = row.blocked || false;
const isActive = !rowArchived && !rowBlocked;
if (selectedValues.includes("active") && isActive)
return true;
if (
selectedValues.includes("archived") &&
rowArchived
)
return true;
if (
selectedValues.includes("blocked") &&
rowBlocked
)
return true;
return false;
}, },
defaultValues: ["active"] // Default to showing active clients values: searchParams.getAll("status")
} }
]} ]}
/> />

View File

@@ -58,6 +58,18 @@ type Resource = {
siteName?: string | null; siteName?: string | null;
}; };
type SiteResource = {
siteResourceId: number;
name: string;
destination: string;
mode: string;
protocol: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
type: 'site';
};
type MemberResourcesPortalProps = { type MemberResourcesPortalProps = {
orgId: string; orgId: string;
}; };
@@ -334,7 +346,9 @@ export default function MemberResourcesPortal({
const { toast } = useToast(); const { toast } = useToast();
const [resources, setResources] = useState<Resource[]>([]); const [resources, setResources] = useState<Resource[]>([]);
const [siteResources, setSiteResources] = useState<SiteResource[]>([]);
const [filteredResources, setFilteredResources] = useState<Resource[]>([]); const [filteredResources, setFilteredResources] = useState<Resource[]>([]);
const [filteredSiteResources, setFilteredSiteResources] = useState<SiteResource[]>([]);
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [searchQuery, setSearchQuery] = useState(""); const [searchQuery, setSearchQuery] = useState("");
@@ -360,7 +374,9 @@ export default function MemberResourcesPortal({
if (response.data.success) { if (response.data.success) {
setResources(response.data.data.resources); setResources(response.data.data.resources);
setSiteResources(response.data.data.siteResources || []);
setFilteredResources(response.data.data.resources); setFilteredResources(response.data.data.resources);
setFilteredSiteResources(response.data.data.siteResources || []);
} else { } else {
setError("Failed to load resources"); setError("Failed to load resources");
} }
@@ -417,17 +433,61 @@ export default function MemberResourcesPortal({
setFilteredResources(filtered); setFilteredResources(filtered);
// Filter and sort site resources
const filteredSites = siteResources.filter(
(resource) =>
resource.name
.toLowerCase()
.includes(searchQuery.toLowerCase()) ||
resource.destination
.toLowerCase()
.includes(searchQuery.toLowerCase())
);
// Sort site resources
filteredSites.sort((a, b) => {
switch (sortBy) {
case "name-asc":
return a.name.localeCompare(b.name);
case "name-desc":
return b.name.localeCompare(a.name);
case "domain-asc":
case "domain-desc":
// Sort by destination for site resources
const destCompare = sortBy === "domain-asc"
? a.destination.localeCompare(b.destination)
: b.destination.localeCompare(a.destination);
return destCompare;
case "status-enabled":
return b.enabled ? 1 : -1;
case "status-disabled":
return a.enabled ? 1 : -1;
default:
return a.name.localeCompare(b.name);
}
});
setFilteredSiteResources(filteredSites);
// Reset to first page when search/sort changes // Reset to first page when search/sort changes
setCurrentPage(1); setCurrentPage(1);
}, [resources, searchQuery, sortBy]); }, [resources, siteResources, searchQuery, sortBy]);
// Calculate pagination // Calculate pagination
const totalPages = Math.ceil(filteredResources.length / itemsPerPage); const totalItems = filteredResources.length + filteredSiteResources.length;
const totalPages = Math.ceil(totalItems / itemsPerPage);
const startIndex = (currentPage - 1) * itemsPerPage; const startIndex = (currentPage - 1) * itemsPerPage;
const paginatedResources = filteredResources.slice( const paginatedResources = filteredResources.slice(
startIndex, startIndex,
startIndex + itemsPerPage startIndex + itemsPerPage
); );
const remainingSlots = itemsPerPage - paginatedResources.length;
const paginatedSiteResources = remainingSlots > 0
? filteredSiteResources.slice(
Math.max(0, startIndex - filteredResources.length),
Math.max(0, startIndex - filteredResources.length) + remainingSlots
)
: [];
const handleOpenResource = (resource: Resource) => { const handleOpenResource = (resource: Resource) => {
// Open the resource in a new tab // Open the resource in a new tab
@@ -575,7 +635,7 @@ export default function MemberResourcesPortal({
</div> </div>
{/* Resources Content */} {/* Resources Content */}
{filteredResources.length === 0 ? ( {filteredResources.length === 0 && filteredSiteResources.length === 0 ? (
/* Enhanced Empty State */ /* Enhanced Empty State */
<Card> <Card>
<CardContent className="flex flex-col items-center justify-center py-20 text-center"> <CardContent className="flex flex-col items-center justify-center py-20 text-center">
@@ -623,8 +683,19 @@ export default function MemberResourcesPortal({
</Card> </Card>
) : ( ) : (
<> <>
{/* Resources Grid */} {/* Public Resources Section */}
<div className="grid gap-5 grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-4 auto-cols-fr"> {paginatedResources.length > 0 && (
<>
<div className="mb-4">
<h3 className="text-lg font-semibold text-foreground flex items-center gap-2">
<Globe className="h-5 w-5" />
Public Resources
</h3>
<p className="text-sm text-muted-foreground mt-1">
Web applications and services accessible via browser
</p>
</div>
<div className="grid gap-5 grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-4 auto-cols-fr mb-8">
{paginatedResources.map((resource) => ( {paginatedResources.map((resource) => (
<Card key={resource.resourceId}> <Card key={resource.resourceId}>
<div className="p-6"> <div className="p-6">
@@ -702,13 +773,167 @@ export default function MemberResourcesPortal({
</Card> </Card>
))} ))}
</div> </div>
</>
)}
{/* Private Resources (Site Resources) Section */}
{paginatedSiteResources.length > 0 && (
<>
<div className="mb-4">
<h3 className="text-lg font-semibold text-foreground flex items-center gap-2">
<Combine className="h-5 w-5" />
Private Resources
</h3>
<p className="text-sm text-muted-foreground mt-1">
Internal network resources accessible via client
</p>
</div>
<div className="grid gap-5 grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-4 auto-cols-fr mb-8">
{paginatedSiteResources.map((siteResource) => (
<Card key={siteResource.siteResourceId}>
<div className="p-6">
<div className="flex items-center justify-between gap-3">
<div className="flex items-center min-w-0 flex-1 gap-3 overflow-hidden">
<TooltipProvider>
<Tooltip>
<TooltipTrigger className="min-w-0 max-w-full">
<CardTitle className="text-lg font-bold text-foreground truncate group-hover:text-primary transition-colors">
{siteResource.name}
</CardTitle>
</TooltipTrigger>
<TooltipContent>
<p className="max-w-xs break-words">
{siteResource.name}
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<div className="flex-shrink-0">
<InfoPopup>
<div className="space-y-2 text-sm">
<div className="text-xs font-medium mb-1.5">Resource Details</div>
<div>
<span className="font-medium">Mode:</span>
<span className="ml-2 text-muted-foreground capitalize">
{siteResource.mode}
</span>
</div>
{siteResource.protocol && (
<div>
<span className="font-medium">Protocol:</span>
<span className="ml-2 text-muted-foreground uppercase">
{siteResource.protocol}
</span>
</div>
)}
{siteResource.alias && (
<div>
<span className="font-medium">Alias:</span>
<span className="ml-2 text-muted-foreground">
{siteResource.alias}
</span>
</div>
)}
{siteResource.aliasAddress && (
<div>
<span className="font-medium">Alias Address:</span>
<span className="ml-2 text-muted-foreground">
{siteResource.aliasAddress}
</span>
</div>
)}
<div>
<span className="font-medium">Status:</span>
<span className={`ml-2 ${siteResource.enabled ? 'text-green-600' : 'text-red-600'}`}>
{siteResource.enabled ? 'Enabled' : 'Disabled'}
</span>
</div>
</div>
</InfoPopup>
</div>
</div>
<div className="mt-3">
{siteResource.alias ? (
<>
{/* Alias as primary */}
<div className="flex items-center gap-2 mb-1">
<div className="text-base font-semibold text-foreground text-left truncate flex-1">
{siteResource.alias}
</div>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 text-muted-foreground"
onClick={() => {
navigator.clipboard.writeText(
siteResource.alias!
);
toast({
title: "Copied to clipboard",
description:
"Resource alias has been copied to your clipboard.",
duration: 2000
});
}}
>
<Copy className="h-4 w-4" />
</Button>
</div>
{/* Destination as secondary */}
<div className="text-xs text-muted-foreground truncate">
{siteResource.destination}
</div>
</>
) : (
/* Destination as primary when no alias */
<div className="flex items-center gap-2">
<div className="text-sm text-muted-foreground font-medium text-left truncate flex-1">
{siteResource.destination}
</div>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 text-muted-foreground"
onClick={() => {
navigator.clipboard.writeText(
siteResource.destination
);
toast({
title: "Copied to clipboard",
description:
"Resource destination has been copied to your clipboard.",
duration: 2000
});
}}
>
<Copy className="h-4 w-4" />
</Button>
</div>
)}
</div>
</div>
<div className="p-6 pt-0 mt-auto">
<div className="flex items-center justify-center py-2 px-4 bg-muted/50 rounded text-sm text-muted-foreground">
<Combine className="h-3.5 w-3.5 mr-2" />
Requires Client Connection
</div>
</div>
</Card>
))}
</div>
</>
)}
{/* Pagination Controls */} {/* Pagination Controls */}
<PaginationControls <PaginationControls
currentPage={currentPage} currentPage={currentPage}
totalPages={totalPages} totalPages={totalPages}
onPageChange={handlePageChange} onPageChange={handlePageChange}
totalItems={filteredResources.length} totalItems={totalItems}
itemsPerPage={itemsPerPage} itemsPerPage={itemsPerPage}
/> />
</> </>

View File

@@ -20,12 +20,13 @@ import {
TooltipProvider, TooltipProvider,
TooltipTrigger TooltipTrigger
} from "@app/components/ui/tooltip"; } from "@app/components/ui/tooltip";
import { Badge } from "@app/components/ui/badge";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { cn } from "@app/lib/cn"; import { cn } from "@app/lib/cn";
import { ListUserOrgsResponse } from "@server/routers/org"; import { ListUserOrgsResponse } from "@server/routers/org";
import { Check, ChevronsUpDown, Plus, Building2, Users } from "lucide-react"; import { Check, ChevronsUpDown, Plus, Building2, Users } from "lucide-react";
import { useRouter } from "next/navigation"; import { usePathname, useRouter } from "next/navigation";
import { useState } from "react"; import { useMemo, useState } from "react";
import { useUserContext } from "@app/hooks/useUserContext"; import { useUserContext } from "@app/hooks/useUserContext";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
@@ -43,11 +44,23 @@ export function OrgSelector({
const { user } = useUserContext(); const { user } = useUserContext();
const [open, setOpen] = useState(false); const [open, setOpen] = useState(false);
const router = useRouter(); const router = useRouter();
const pathname = usePathname();
const { env } = useEnvContext(); const { env } = useEnvContext();
const t = useTranslations(); const t = useTranslations();
const selectedOrg = orgs?.find((org) => org.orgId === orgId); const selectedOrg = orgs?.find((org) => org.orgId === orgId);
const sortedOrgs = useMemo(() => {
if (!orgs?.length) return orgs ?? [];
return [...orgs].sort((a, b) => {
const aPrimary = Boolean(a.isPrimaryOrg);
const bPrimary = Boolean(b.isPrimaryOrg);
if (aPrimary && !bPrimary) return -1;
if (!aPrimary && bPrimary) return 1;
return 0;
});
}, [orgs]);
const orgSelectorContent = ( const orgSelectorContent = (
<Popover open={open} onOpenChange={setOpen}> <Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger asChild> <PopoverTrigger asChild>
@@ -83,7 +96,7 @@ export function OrgSelector({
<PopoverContent className="w-[320px] p-0" align="start"> <PopoverContent className="w-[320px] p-0" align="start">
<Command className="rounded-lg"> <Command className="rounded-lg">
<CommandInput <CommandInput
placeholder={t("searchProgress")} placeholder={t("searchPlaceholder")}
className="border-0 focus:ring-0" className="border-0 focus:ring-0"
/> />
<CommandEmpty className="py-6 text-center"> <CommandEmpty className="py-6 text-center">
@@ -124,25 +137,39 @@ export function OrgSelector({
)} )}
<CommandGroup heading={t("orgs")} className="py-2"> <CommandGroup heading={t("orgs")} className="py-2">
<CommandList> <CommandList>
{orgs?.map((org) => ( {sortedOrgs.map((org) => (
<CommandItem <CommandItem
key={org.orgId} key={org.orgId}
onSelect={() => { onSelect={() => {
setOpen(false); setOpen(false);
router.push(`/${org.orgId}/settings`); const newPath = pathname.replace(
/^\/[^/]+/,
`/${org.orgId}`
);
router.push(newPath);
}} }}
className="mx-2 rounded-md" className="mx-2 rounded-md"
> >
<div className="flex items-center justify-center w-8 h-8 rounded-lg bg-muted mr-3"> <div className="flex items-center justify-center w-8 h-8 rounded-lg bg-muted mr-3">
<Users className="h-4 w-4 text-muted-foreground" /> <Users className="h-4 w-4 text-muted-foreground" />
</div> </div>
<div className="flex flex-col flex-1"> <div className="flex flex-col flex-1 min-w-0">
<span className="font-medium"> <span className="font-medium truncate">
{org.name} {org.name}
</span> </span>
<span className="text-xs text-muted-foreground"> <div className="flex items-center gap-2 min-w-0">
{t("organization")} <span className="text-xs text-muted-foreground font-mono truncate">
{org.orgId}
</span> </span>
{org.isPrimaryOrg && (
<Badge
variant="outline"
className="shrink-0 text-[10px] px-1.5 py-0 font-medium ml-auto"
>
{t("primary")}
</Badge>
)}
</div>
</div> </div>
<Check <Check
className={cn( className={cn(

View File

@@ -2,9 +2,8 @@
import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog"; import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog";
import CopyToClipboard from "@app/components/CopyToClipboard"; import CopyToClipboard from "@app/components/CopyToClipboard";
import { DataTable } from "@app/components/ui/data-table";
import { ExtendedColumnDef } from "@app/components/ui/data-table";
import { Button } from "@app/components/ui/button"; import { Button } from "@app/components/ui/button";
import { ExtendedColumnDef } from "@app/components/ui/data-table";
import { import {
DropdownMenu, DropdownMenu,
DropdownMenuContent, DropdownMenuContent,
@@ -14,13 +13,14 @@ import {
import { InfoPopup } from "@app/components/ui/info-popup"; import { InfoPopup } from "@app/components/ui/info-popup";
import { Switch } from "@app/components/ui/switch"; import { Switch } from "@app/components/ui/switch";
import { useEnvContext } from "@app/hooks/useEnvContext"; import { useEnvContext } from "@app/hooks/useEnvContext";
import { useNavigationContext } from "@app/hooks/useNavigationContext";
import { toast } from "@app/hooks/useToast"; import { toast } from "@app/hooks/useToast";
import { createApiClient, formatAxiosError } from "@app/lib/api"; import { createApiClient, formatAxiosError } from "@app/lib/api";
import { UpdateResourceResponse } from "@server/routers/resource"; import { UpdateResourceResponse } from "@server/routers/resource";
import type { PaginationState } from "@tanstack/react-table";
import { AxiosResponse } from "axios"; import { AxiosResponse } from "axios";
import { import {
ArrowRight, ArrowRight,
ArrowUpDown,
CheckCircle2, CheckCircle2,
ChevronDown, ChevronDown,
Clock, Clock,
@@ -32,14 +32,24 @@ import {
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import Link from "next/link"; import Link from "next/link";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useState, useTransition } from "react"; import {
useOptimistic,
useRef,
useState,
useTransition,
type ComponentRef
} from "react";
import { useDebouncedCallback } from "use-debounce";
import z from "zod";
import { ColumnFilterButton } from "./ColumnFilterButton";
import { ControlledDataTable } from "./ui/controlled-data-table";
export type TargetHealth = { export type TargetHealth = {
targetId: number; targetId: number;
ip: string; ip: string;
port: number; port: number;
enabled: boolean; enabled: boolean;
healthStatus?: "healthy" | "unhealthy" | "unknown"; healthStatus: "healthy" | "unhealthy" | "unknown" | null;
}; };
export type ResourceRow = { export type ResourceRow = {
@@ -117,18 +127,22 @@ function StatusIcon({
type ProxyResourcesTableProps = { type ProxyResourcesTableProps = {
resources: ResourceRow[]; resources: ResourceRow[];
orgId: string; orgId: string;
defaultSort?: { pagination: PaginationState;
id: string; rowCount: number;
desc: boolean;
};
}; };
export default function ProxyResourcesTable({ export default function ProxyResourcesTable({
resources, resources,
orgId, orgId,
defaultSort pagination,
rowCount
}: ProxyResourcesTableProps) { }: ProxyResourcesTableProps) {
const router = useRouter(); const router = useRouter();
const {
navigate: filter,
isNavigating: isFiltering,
searchParams
} = useNavigationContext();
const t = useTranslations(); const t = useTranslations();
const { env } = useEnvContext(); const { env } = useEnvContext();
@@ -140,6 +154,7 @@ export default function ProxyResourcesTable({
useState<ResourceRow | null>(); useState<ResourceRow | null>();
const [isRefreshing, startTransition] = useTransition(); const [isRefreshing, startTransition] = useTransition();
const [isNavigatingToAddPage, startNavigation] = useTransition();
const refreshData = () => { const refreshData = () => {
startTransition(() => { startTransition(() => {
@@ -174,14 +189,15 @@ export default function ProxyResourcesTable({
}; };
async function toggleResourceEnabled(val: boolean, resourceId: number) { async function toggleResourceEnabled(val: boolean, resourceId: number) {
await api try {
.post<AxiosResponse<UpdateResourceResponse>>( await api.post<AxiosResponse<UpdateResourceResponse>>(
`resource/${resourceId}`, `resource/${resourceId}`,
{ {
enabled: val enabled: val
} }
) );
.catch((e) => { router.refresh();
} catch (e) {
toast({ toast({
variant: "destructive", variant: "destructive",
title: t("resourcesErrorUpdate"), title: t("resourcesErrorUpdate"),
@@ -190,7 +206,7 @@ export default function ProxyResourcesTable({
t("resourcesErrorUpdateDescription") t("resourcesErrorUpdateDescription")
) )
}); });
}); }
} }
function TargetStatusCell({ targets }: { targets?: TargetHealth[] }) { function TargetStatusCell({ targets }: { targets?: TargetHealth[] }) {
@@ -236,7 +252,7 @@ export default function ProxyResourcesTable({
<ChevronDown className="h-3 w-3" /> <ChevronDown className="h-3 w-3" />
</Button> </Button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="start" className="min-w-[280px]"> <DropdownMenuContent align="start" className="min-w-70">
{monitoredTargets.length > 0 && ( {monitoredTargets.length > 0 && (
<> <>
{monitoredTargets.map((target) => ( {monitoredTargets.map((target) => (
@@ -302,38 +318,14 @@ export default function ProxyResourcesTable({
accessorKey: "name", accessorKey: "name",
enableHiding: false, enableHiding: false,
friendlyName: t("name"), friendlyName: t("name"),
header: ({ column }) => { header: () => <span className="p-3">{t("name")}</span>
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(column.getIsSorted() === "asc")
}
>
{t("name")}
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
}
}, },
{ {
id: "niceId", id: "niceId",
accessorKey: "nice", accessorKey: "nice",
friendlyName: t("identifier"), friendlyName: t("identifier"),
enableHiding: true, enableHiding: true,
header: ({ column }) => { header: () => <span className="p-3">{t("identifier")}</span>,
return (
<Button
variant="ghost"
onClick={() =>
column.toggleSorting(column.getIsSorted() === "asc")
}
>
{t("identifier")}
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
},
cell: ({ row }) => { cell: ({ row }) => {
return <span>{row.original.nice || "-"}</span>; return <span>{row.original.nice || "-"}</span>;
} }
@@ -359,19 +351,33 @@ export default function ProxyResourcesTable({
id: "status", id: "status",
accessorKey: "status", accessorKey: "status",
friendlyName: t("status"), friendlyName: t("status"),
header: ({ column }) => { header: () => (
return ( <ColumnFilterButton
<Button options={[
variant="ghost" { value: "healthy", label: t("resourcesTableHealthy") },
onClick={() => {
column.toggleSorting(column.getIsSorted() === "asc") value: "degraded",
} label: t("resourcesTableDegraded")
>
{t("status")}
<ArrowUpDown className="ml-2 h-4 w-4" />
</Button>
);
}, },
{ value: "offline", label: t("resourcesTableOffline") },
{
value: "no_targets",
label: t("resourcesTableNoTargets")
},
{ value: "unknown", label: t("resourcesTableUnknown") }
]}
selectedValue={
searchParams.get("healthStatus") ?? undefined
}
onValueChange={(value) =>
handleFilterChange("healthStatus", value)
}
searchPlaceholder={t("searchPlaceholder")}
emptyMessage={t("emptySearchOptions")}
label={t("status")}
className="p-3"
/>
),
cell: ({ row }) => { cell: ({ row }) => {
const resourceRow = row.original; const resourceRow = row.original;
return <TargetStatusCell targets={resourceRow.targets} />; return <TargetStatusCell targets={resourceRow.targets} />;
@@ -419,19 +425,23 @@ export default function ProxyResourcesTable({
{ {
accessorKey: "authState", accessorKey: "authState",
friendlyName: t("authentication"), friendlyName: t("authentication"),
header: ({ column }) => { header: () => (
return ( <ColumnFilterButton
<Button options={[
variant="ghost" { value: "protected", label: t("protected") },
onClick={() => { value: "not_protected", label: t("notProtected") },
column.toggleSorting(column.getIsSorted() === "asc") { value: "none", label: t("none") }
]}
selectedValue={searchParams.get("authState") ?? undefined}
onValueChange={(value) =>
handleFilterChange("authState", value)
} }
> searchPlaceholder={t("searchPlaceholder")}
{t("authentication")} emptyMessage={t("emptySearchOptions")}
<ArrowUpDown className="ml-2 h-4 w-4" /> label={t("authentication")}
</Button> className="p-3"
); />
}, ),
cell: ({ row }) => { cell: ({ row }) => {
const resourceRow = row.original; const resourceRow = row.original;
return ( return (
@@ -456,20 +466,28 @@ export default function ProxyResourcesTable({
{ {
accessorKey: "enabled", accessorKey: "enabled",
friendlyName: t("enabled"), friendlyName: t("enabled"),
header: () => <span className="p-3">{t("enabled")}</span>, header: () => (
<ColumnFilterButton
options={[
{ value: "true", label: t("enabled") },
{ value: "false", label: t("disabled") }
]}
selectedValue={booleanSearchFilterSchema.parse(
searchParams.get("enabled")
)}
onValueChange={(value) =>
handleFilterChange("enabled", value)
}
searchPlaceholder={t("searchPlaceholder")}
emptyMessage={t("emptySearchOptions")}
label={t("enabled")}
className="p-3"
/>
),
cell: ({ row }) => ( cell: ({ row }) => (
<Switch <ResourceEnabledForm
defaultChecked={ resource={row.original}
row.original.http onToggleResourceEnabled={toggleResourceEnabled}
? !!row.original.domainId && row.original.enabled
: row.original.enabled
}
disabled={
row.original.http ? !row.original.domainId : false
}
onCheckedChange={(val) =>
toggleResourceEnabled(val, row.original.id)
}
/> />
) )
}, },
@@ -525,6 +543,42 @@ export default function ProxyResourcesTable({
} }
]; ];
const booleanSearchFilterSchema = z
.enum(["true", "false"])
.optional()
.catch(undefined);
function handleFilterChange(
column: string,
value: string | undefined | null
) {
searchParams.delete(column);
searchParams.delete("page");
if (value) {
searchParams.set(column, value);
}
filter({
searchParams
});
}
const handlePaginationChange = (newPage: PaginationState) => {
searchParams.set("page", (newPage.pageIndex + 1).toString());
searchParams.set("pageSize", newPage.pageSize.toString());
filter({
searchParams
});
};
const handleSearchChange = useDebouncedCallback((query: string) => {
searchParams.set("query", query);
searchParams.delete("page");
filter({
searchParams
});
}, 300);
return ( return (
<> <>
{selectedResource && ( {selectedResource && (
@@ -547,21 +601,25 @@ export default function ProxyResourcesTable({
/> />
)} )}
<DataTable <ControlledDataTable
columns={proxyColumns} columns={proxyColumns}
data={resources} rows={resources}
persistPageSize="proxy-resources" tableId="proxy-resources"
searchPlaceholder={t("resourcesSearch")} searchPlaceholder={t("resourcesSearch")}
searchColumn="name" pagination={pagination}
rowCount={rowCount}
onSearch={handleSearchChange}
onPaginationChange={handlePaginationChange}
onAdd={() => onAdd={() =>
startNavigation(() =>
router.push(`/${orgId}/settings/resources/proxy/create`) router.push(`/${orgId}/settings/resources/proxy/create`)
)
} }
addButtonText={t("resourceAdd")} addButtonText={t("resourceAdd")}
onRefresh={refreshData} onRefresh={refreshData}
isRefreshing={isRefreshing} isRefreshing={isRefreshing || isFiltering}
defaultSort={defaultSort} isNavigatingToAddPage={isNavigatingToAddPage}
enableColumnVisibility={true} enableColumnVisibility
persistColumnVisibility="proxy-resources"
columnVisibility={{ niceId: false }} columnVisibility={{ niceId: false }}
stickyLeftColumn="name" stickyLeftColumn="name"
stickyRightColumn="actions" stickyRightColumn="actions"
@@ -569,3 +627,43 @@ export default function ProxyResourcesTable({
</> </>
); );
} }
type ResourceEnabledFormProps = {
resource: ResourceRow;
onToggleResourceEnabled: (
val: boolean,
resourceId: number
) => Promise<void>;
};
function ResourceEnabledForm({
resource,
onToggleResourceEnabled
}: ResourceEnabledFormProps) {
const enabled = resource.http
? !!resource.domainId && resource.enabled
: resource.enabled;
const [optimisticEnabled, setOptimisticEnabled] = useOptimistic(enabled);
const formRef = useRef<ComponentRef<"form">>(null);
async function submitAction(formData: FormData) {
const newEnabled = !(formData.get("enabled") === "on");
setOptimisticEnabled(newEnabled);
await onToggleResourceEnabled(newEnabled, resource.id);
}
return (
<form action={submitAction} ref={formRef}>
<Switch
checked={optimisticEnabled}
disabled={
(resource.http && !resource.domainId) ||
optimisticEnabled !== enabled
}
name="enabled"
onCheckedChange={() => formRef.current?.requestSubmit()}
/>
</form>
);
}

View File

@@ -72,6 +72,7 @@ type SignupFormProps = {
inviteToken?: string; inviteToken?: string;
emailParam?: string; emailParam?: string;
fromSmartLogin?: boolean; fromSmartLogin?: boolean;
skipVerificationEmail?: boolean;
}; };
const formSchema = z const formSchema = z
@@ -103,7 +104,8 @@ export default function SignupForm({
inviteId, inviteId,
inviteToken, inviteToken,
emailParam, emailParam,
fromSmartLogin = false fromSmartLogin = false,
skipVerificationEmail = false
}: SignupFormProps) { }: SignupFormProps) {
const router = useRouter(); const router = useRouter();
const { env } = useEnvContext(); const { env } = useEnvContext();
@@ -147,7 +149,8 @@ export default function SignupForm({
inviteToken, inviteToken,
termsAcceptedTimestamp: termsAgreedAt, termsAcceptedTimestamp: termsAgreedAt,
marketingEmailConsent: marketingEmailConsent:
build === "saas" ? marketingEmailConsent : undefined build === "saas" ? marketingEmailConsent : undefined,
skipVerificationEmail: skipVerificationEmail || undefined
}) })
.catch((e) => { .catch((e) => {
console.error(e); console.error(e);

Some files were not shown because too many files have changed in this diff Show More