diff --git a/src/config.ts b/src/config.ts index 7b573c1c..1bd95567 100644 --- a/src/config.ts +++ b/src/config.ts @@ -114,6 +114,7 @@ type StorageConfigType = { requestUrlLengthLimit: number requestXForwardedHostRegExp?: string requestAllowXForwardedPrefix?: boolean + storagePublicUrl?: string logLevel?: string logflareEnabled?: boolean logflareApiKey?: string @@ -284,6 +285,7 @@ export function getConfig(options?: { reload?: boolean }): StorageConfigType { ), requestAllowXForwardedPrefix: getOptionalConfigFromEnv('REQUEST_ALLOW_X_FORWARDED_PATH') === 'true', + storagePublicUrl: getOptionalConfigFromEnv('STORAGE_PUBLIC_URL'), requestUrlLengthLimit: Number(getOptionalConfigFromEnv('REQUEST_URL_LENGTH_LIMIT', 'URL_LENGTH_LIMIT')) || 7_500, requestTraceHeader: getOptionalConfigFromEnv('REQUEST_TRACE_HEADER', 'REQUEST_ID_HEADER'), diff --git a/src/http/routes/tus/lifecycle.ts b/src/http/routes/tus/lifecycle.ts index 6ab40b71..610bfb1e 100644 --- a/src/http/routes/tus/lifecycle.ts +++ b/src/http/routes/tus/lifecycle.ts @@ -11,7 +11,8 @@ import type { ServerRequest as Request } from 'srvx' import { getConfig } from '../../../config' -const { storageS3Bucket, tusPath, requestAllowXForwardedPrefix } = getConfig() +const { storageS3Bucket, tusPath, requestAllowXForwardedPrefix, storagePublicUrl } = getConfig() +const parsedPublicUrl = storagePublicUrl ? new URL(storagePublicUrl) : undefined const reExtractFileID = /([^/]+)\/?$/ export const SIGNED_URL_SUFFIX = '/sign' @@ -112,6 +113,11 @@ export function generateUrl( throw ERRORS.InvalidParameter('url') } + if (parsedPublicUrl) { + proto = parsedPublicUrl.protocol.replace(':', '') + host = parsedPublicUrl.host + } + proto = process.env.NODE_ENV === 'production' ? 'https' : proto let basePath = path @@ -125,7 +131,7 @@ export function generateUrl( const isSigned = req.url?.endsWith(SIGNED_URL_SUFFIX) const fullPath = isSigned ? `${basePath}${SIGNED_URL_SUFFIX}` : basePath - if (req.headers['x-forwarded-host']) { + if (!parsedPublicUrl && req.headers['x-forwarded-host']) { const port = req.headers['x-forwarded-port'] if (typeof port === 'string' && port && !['443', '80'].includes(port)) {