Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions extern/threestudio/gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from pathlib import Path
import json

import gradio as gr
import numpy as np
Expand Down Expand Up @@ -202,7 +204,49 @@ def run(
# manually assign the output directory, name and tag so that we know the trial directory
name = os.path.basename(model_config[model_name]["path"]).split(".")[0]
tag = datetime.now().strftime("@%Y%m%d-%H%M%S")
trial_dir = os.path.join(save_root, EXP_ROOT_DIR, name, tag)

# normalize and validate save_root to avoid using unexpected directories
base_root = (Path.cwd() / EXP_ROOT_DIR).resolve()
try:
candidate_root = Path(save_root).expanduser().resolve()
# ensure the requested root stays within the base_root
_ = candidate_root.relative_to(base_root)
safe_save_root = candidate_root
except Exception:
# fall back to the default base_root on invalid or unsafe input
safe_save_root = base_root

# sanitize and constrain user-provided values before constructing the command
# Ensure prompt is safely escaped so it cannot break the expected key="value" syntax.
try:
# json.dumps returns a double-quoted string; strip the outer quotes so we can
# embed it inside key="...".
safe_prompt_inner = json.dumps(str(prompt))[1:-1]
except Exception:
safe_prompt_inner = ""

# Constrain guidance_scale to a reasonable numeric range.
try:
safe_guidance_scale = float(guidance_scale)
except (TypeError, ValueError):
safe_guidance_scale = 0.0
safe_guidance_scale = max(0.0, min(100.0, safe_guidance_scale))

# Constrain seed to a 32-bit signed integer range.
try:
safe_seed = int(seed)
except (TypeError, ValueError):
safe_seed = 0
safe_seed = max(0, min(2147483647, safe_seed))

# Constrain max_steps to a positive integer within an upper bound.
try:
safe_max_steps = int(max_steps)
except (TypeError, ValueError):
safe_max_steps = 1
safe_max_steps = max(1, min(20000, safe_max_steps))

trial_dir = os.path.join(str(safe_save_root), name, tag)
alive_path = os.path.join(trial_dir, "alive")

# spawn the training process
Expand All @@ -212,12 +256,12 @@ def run(
+ [
f'name="{name}"',
f'tag="{tag}"',
f"exp_root_dir={os.path.join(save_root, EXP_ROOT_DIR)}",
f"exp_root_dir={str(safe_save_root)}",
"use_timestamp=false",
f'system.prompt_processor.prompt="{prompt}"',
f"system.guidance.guidance_scale={guidance_scale}",
f"seed={seed}",
f"trainer.max_steps={max_steps}",
f'system.prompt_processor.prompt="{safe_prompt_inner}"',
f"system.guidance.guidance_scale={safe_guidance_scale}",
f"seed={safe_seed}",
f"trainer.max_steps={safe_max_steps}",
]
+ (
["checkpoint.every_n_train_steps=${trainer.max_steps}"] if save_ckpt else []
Expand Down