From 3c261e9a1d8694c3d54708c9ce45dea4a9228662 Mon Sep 17 00:00:00 2001 From: begna112 Date: Thu, 12 Feb 2026 22:36:22 -0600 Subject: [PATCH] feat: CLI bug fixes, improved help text, and test infrastructure Bug fixes: - Fix mutable default arguments in http_put/post/del - Add missing exception handling in http_request - Fix timezone handling with UTC-aware timestamps - Fix show_instances row rebinding bug - Fix show_machine dict response handling - Fix parse_query field alias data loss - Remove hardcoded if True anti-patterns - Fix raw mode to return parsed JSON consistently - Fix namespace attribute typo - Replace bare except clauses with specific exceptions - Remove shell=True from subprocess calls - Fix self-test instance creation type handling and error messages Infrastructure improvements: - Add request timeout support with configurable defaults - Expand retry logic to cover 502/503/504 and connection errors - Add JSONDecodeError protection to high-risk .json() sites - Consolidate direct requests calls through http_* helpers - Add structured JSON error output for --raw mode - Add raw mode return handling to all command functions - Remove unreachable code after raise_for_status() - Add api_call/output_result/error_output DRY helpers - Add 2FA session expiry handling with automatic key fallback - Filter sensitive fields from show endpoints output CLI improvements: - Improve help text and descriptions for 130+ commands - Add named flags for transfer credit (--sender, --recipient) - Fix argument mismatches in usage strings - Standardize show maints to use --ids instead of -ids - Remove Python 2 compatibility code Test infrastructure: - Add pytest with regression, smoke, and unit test suites - Add CI workflow with lint, test, and smoke-test jobs - Add ruff configuration for code quality - Add dev dependencies to pyproject.toml --- .github/workflows/ci.yml | 94 + .gitignore | 13 + __init__.py | 5 +- conftest.py | 1 + pyproject.toml | 50 +- tests/__init__.py | 0 
tests/conftest.py | 76 + tests/regression/__init__.py | 0 tests/regression/test_bare_except.py | 102 + tests/regression/test_direct_requests.py | 146 + tests/regression/test_hardcoded_raw.py | 61 + tests/regression/test_http_exception.py | 72 + tests/regression/test_json_decode.py | 128 + tests/regression/test_main_error_handling.py | 221 + tests/regression/test_main_raw_output.py | 199 + tests/regression/test_mutable_defaults.py | 54 + tests/regression/test_namespace_typo.py | 23 + tests/regression/test_parse_query.py | 59 + tests/regression/test_raw_completeness.py | 176 + tests/regression/test_raw_errors.py | 188 + tests/regression/test_retry.py | 233 + tests/regression/test_return_response.py | 58 + tests/regression/test_safe_dict_access.py | 237 + tests/regression/test_shell_injection.py | 116 + tests/regression/test_show_instances.py | 90 + tests/regression/test_show_machine.py | 58 + tests/regression/test_timeout.py | 129 + tests/regression/test_timezone_handling.py | 41 + tests/regression/test_unreachable_code.py | 227 + tests/regression/test_utc_display.py | 97 + tests/regression/test_utcfromtimestamp.py | 93 + tests/smoke/__init__.py | 1 + tests/smoke/test_cli_commands.py | 429 ++ tests/smoke/test_standalone.py | 275 + tests/unit/__init__.py | 1 + tests/unit/test_http_helpers.py | 267 + tests/unit/test_query_parser.py | 160 + tests/unit/test_timezone.py | 125 + vast.py | 6200 ++++++++++-------- 39 files changed, 7662 insertions(+), 2843 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 conftest.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/regression/__init__.py create mode 100644 tests/regression/test_bare_except.py create mode 100644 tests/regression/test_direct_requests.py create mode 100644 tests/regression/test_hardcoded_raw.py create mode 100644 tests/regression/test_http_exception.py create mode 100644 tests/regression/test_json_decode.py create mode 100644 
tests/regression/test_main_error_handling.py create mode 100644 tests/regression/test_main_raw_output.py create mode 100644 tests/regression/test_mutable_defaults.py create mode 100644 tests/regression/test_namespace_typo.py create mode 100644 tests/regression/test_parse_query.py create mode 100644 tests/regression/test_raw_completeness.py create mode 100644 tests/regression/test_raw_errors.py create mode 100644 tests/regression/test_retry.py create mode 100644 tests/regression/test_return_response.py create mode 100644 tests/regression/test_safe_dict_access.py create mode 100644 tests/regression/test_shell_injection.py create mode 100644 tests/regression/test_show_instances.py create mode 100644 tests/regression/test_show_machine.py create mode 100644 tests/regression/test_timeout.py create mode 100644 tests/regression/test_timezone_handling.py create mode 100644 tests/regression/test_unreachable_code.py create mode 100644 tests/regression/test_utc_display.py create mode 100644 tests/regression/test_utcfromtimestamp.py create mode 100644 tests/smoke/__init__.py create mode 100644 tests/smoke/test_cli_commands.py create mode 100644 tests/smoke/test_standalone.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_http_helpers.py create mode 100644 tests/unit/test_query_parser.py create mode 100644 tests/unit/test_timezone.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..1dbbd9d4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,94 @@ +name: CI + +on: + push: + branches: + - master + - main + pull_request: + branches: + - master + - main + +jobs: + lint: + name: Lint (Ruff) + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install Ruff + run: pip install ruff + + - name: Run Ruff check + run: ruff check . 
--output-format=github + + - name: Run Ruff format check + run: ruff format --check . + + test: + name: Test (${{ matrix.os }} / Python ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install dependencies + run: pip install -e ".[dev]" + + - name: Run pytest with coverage + run: pytest --cov=vast --cov-report=xml + + - name: Upload coverage to Codecov + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + uses: codecov/codecov-action@v4 + with: + files: ./coverage.xml + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + smoke-standalone: + name: Smoke Test Standalone (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install minimal dependencies (standalone mode) + run: pip install requests python-dateutil + + - name: Test vast.py --help + run: python vast.py --help + + - name: Test vast.py search offers --help + run: python vast.py search offers --help + + - name: Test vast.py show instances --help + run: python vast.py show instances --help diff --git a/.gitignore b/.gitignore index a3df4740..28390020 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,17 @@ passed_machines.txt failed_machines.txt Pass_testresults.log dist/ +build/ __pycache__/ +*.egg-info/ +*.egg +.eggs/ +*.pyc +*.pyo + +# Test artifacts +.coverage +.pytest_cache/ + +# MkDocs build output +site/ diff --git 
a/__init__.py b/__init__.py index 426f99d8..f2428881 100644 --- a/__init__.py +++ b/__init__.py @@ -1 +1,4 @@ -from .vastai_sdk import VastAI \ No newline at end of file +try: + from .vastai_sdk import VastAI +except ImportError: + pass diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..ecd3474f --- /dev/null +++ b/conftest.py @@ -0,0 +1 @@ +collect_ignore = ["__init__.py", "vast.py", "vast_pdf.py", "vast_config.py"] diff --git a/pyproject.toml b/pyproject.toml index aeb35545..f22db580 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,11 +36,22 @@ dependencies = [ "cryptography (>=44.0.2,<45.0.0)", "rich", "fonttools>=4.60.2", - "qrcode" + "qrcode", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-cov>=4.0", + "mypy>=1.14", + "types-requests>=2.32", + "types-python-dateutil>=2.9", + "ruff>=0.9", + "pre-commit>=4.0", ] [tool.poetry] -packages = [{ include = "utils" }, { include = "vast.py" }] +packages = [{ include = "utils" }, { include = "vast.py" }, { include = "vast_config.py" }] version = "0.0.0" [project.scripts] @@ -58,3 +69,38 @@ style = "semver" [tool.poetry.requires-plugins] poetry-dynamic-versioning = { version = ">=1.0.0,<2.0.0", extras = ["plugin"] } + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.coverage.run] +branch = true +source = ["."] +omit = ["*/tests/*", "*/.venv/*", "*/site-packages/*"] + +[tool.coverage.report] +precision = 2 +show_missing = true + +[tool.ruff] +line-length = 120 +target-version = "py310" +exclude = [".git", ".venv", "__pycache__", "build", "dist", ".eggs", "*.egg-info"] + +[tool.ruff.lint] +select = ["E", "W", "F", "I", "B", "C4", "UP"] +ignore = ["E501"] + +[tool.ruff.lint.per-file-ignores] +"vast.py" = ["T20", "A001", "A002", "E501", "B006", "F401", "F523", "F541", "F811", "F841", "E701", "E703", "E711", "E713", "E721", "E741"] +"vast_pdf.py" = ["T20", "E501"] +"tests/**/*.py" = ["F401", "S101"] + +[tool.ruff.lint.isort] +known-first-party = ["vast"] + 
+[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..efd75be9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,76 @@ +import json +import os +import pytest +import argparse +import requests +from unittest.mock import MagicMock, patch + + +@pytest.fixture +def mock_args(): + """Minimal argparse.Namespace for testing CLI functions.""" + return argparse.Namespace( + api_key="test-key", + url="https://console.vast.ai", + retry=3, + raw=False, + explain=False, + quiet=False, + curl=False, + full=False, + no_color=True, + debugging=False, + ) + + +@pytest.fixture +def mock_response(): + """Mock HTTP response with configurable status and JSON body.""" + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"success": True} + response.text = '{"success": true}' + response.content = b'{"success": true}' + response.headers = {"Content-Type": "application/json"} + response.raise_for_status = MagicMock() + return response + + +@pytest.fixture +def mock_api_response(): + """Factory fixture for creating mock API responses with configurable status and data.""" + def _make_response(status_code=200, json_data=None, text=None, headers=None): + response = MagicMock() + response.status_code = status_code + response.json.return_value = json_data if json_data is not None else {} + response.text = text if text is not None else json.dumps(json_data or {}) + response.content = response.text.encode() + response.headers = headers if headers is not None else {"Content-Type": "application/json"} + if status_code >= 400: + response.raise_for_status.side_effect = requests.HTTPError(f"{status_code} Error") + else: + response.raise_for_status = MagicMock() + return response + return _make_response + + 
+@pytest.fixture +def mock_http_get(mock_api_response): + """Patch vast.http_get to return controlled responses.""" + with patch('vast.http_get') as mock: + mock.return_value = mock_api_response(200, {"success": True}) + yield mock + + +@pytest.fixture +def mock_http_post(mock_api_response): + """Patch vast.http_post to return controlled responses.""" + with patch('vast.http_post') as mock: + mock.return_value = mock_api_response(200, {"success": True}) + yield mock + + +@pytest.fixture +def vast_cli_path(): + """Return path to vast.py for subprocess tests.""" + return os.path.join(os.path.dirname(__file__), '..', 'vast.py') diff --git a/tests/regression/__init__.py b/tests/regression/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/regression/test_bare_except.py b/tests/regression/test_bare_except.py new file mode 100644 index 00000000..a2ac0ebb --- /dev/null +++ b/tests/regression/test_bare_except.py @@ -0,0 +1,102 @@ +"""No bare except: clauses in vast.py. + +The bug: Bare except: catches SystemExit and KeyboardInterrupt, which can +mask critical errors and make the program unresponsive to Ctrl+C. It also +swallows programming errors (NameError, TypeError) that should crash loudly. + +The fix: Replace all bare except: with specific exception types appropriate +to each try block's expected failure modes. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestNoBareExcept: + """Lint-style tests verifying no bare except: remains in vast.py.""" + + def test_no_bare_except(self): + """No bare except: clauses should exist in vast.py. + + A bare except: catches everything including SystemExit and + KeyboardInterrupt, making Ctrl+C ineffective and masking bugs. 
+ """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + # Match "except:" but not "except SomeName:" or "except (A, B):" + bare_excepts = re.findall(r'^\s*except\s*:', content, re.MULTILINE) + assert len(bare_excepts) == 0, ( + f"Found {len(bare_excepts)} bare except: clauses in vast.py. " + "Each except must catch specific exception types." + ) + + def test_except_clauses_have_types(self): + """Every except clause should specify at least one exception type.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + # Find all except lines + except_lines = re.findall(r'^\s*except\b.*:', content, re.MULTILINE) + for line in except_lines: + stripped = line.strip() + # Must be "except SomeType:" or "except (A, B) as e:" etc. + # NOT just "except:" + assert stripped != "except:", ( + f"Found bare except: -- should catch specific types: {line!r}" + ) + + +class TestKeyboardInterruptPropagation: + """Verify KeyboardInterrupt is not swallowed by import handlers.""" + + def test_argcomplete_import_does_not_catch_keyboard_interrupt(self): + """The argcomplete import try/except should not catch KeyboardInterrupt. + + Before the fix, bare except: would catch KeyboardInterrupt during + import, making Ctrl+C during startup silently ignored. 
+ """ + import importlib + import unittest.mock as mock + + # Simulate argcomplete import raising KeyboardInterrupt + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def mock_import(name, *args, **kwargs): + if name == 'argcomplete': + raise KeyboardInterrupt() + return original_import(name, *args, **kwargs) + + # The except ImportError: handler should NOT catch KeyboardInterrupt + # so it should propagate + with pytest.raises(KeyboardInterrupt): + with mock.patch('builtins.__import__', side_effect=mock_import): + # Re-execute the import block logic + try: + __import__('argcomplete') + except ImportError: + pass # This is what the fixed code does + + def test_argcomplete_import_catches_import_error(self): + """The argcomplete import try/except should catch ImportError.""" + import unittest.mock as mock + + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def mock_import(name, *args, **kwargs): + if name == 'argcomplete': + raise ImportError("No module named 'argcomplete'") + return original_import(name, *args, **kwargs) + + # ImportError should be caught (not propagated) + caught = False + with mock.patch('builtins.__import__', side_effect=mock_import): + try: + __import__('argcomplete') + except ImportError: + caught = True + assert caught, "ImportError should be caught by except ImportError:" diff --git a/tests/regression/test_direct_requests.py b/tests/regression/test_direct_requests.py new file mode 100644 index 00000000..97be8800 --- /dev/null +++ b/tests/regression/test_direct_requests.py @@ -0,0 +1,146 @@ +"""No direct requests.get/post calls outside http_* helpers. + +The bug: Several functions used requests.get() or requests.post() directly, +bypassing the centralized http_* helpers that provide timeout, retry, and +error handling. + +The fix: Convert all direct requests.get/post calls in CLI command functions +to use http_get/http_post. 
Only allowed exceptions are: + - get_project_data(): module-level PyPI check, no args available + - fetch_url_content(): utility function, no args available (dead code) + - _get_gpu_names(): module-level GPU cache, no args available + - http_request(): the low-level implementation that uses requests.Session + - import statements and commented-out code +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import ast +import re + + +# Allowed locations for direct requests.get/post calls +ALLOWED_FUNCTIONS = { + 'get_project_data', # Module-level PyPI check, no args + 'fetch_url_content', # Utility function, no args (dead code) + '_get_gpu_names', # Module-level GPU cache, no args + 'http_request', # Low-level implementation +} + + +def _get_vast_source(): + """Read the vast.py source file.""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, encoding='utf-8') as f: + return f.read() + + +def test_no_direct_requests_get_in_command_functions(): + """No requests.get() calls exist in CLI command functions (outside allowed exceptions).""" + source = _get_vast_source() + tree = ast.parse(source) + + violations = [] + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + func_name = node.name + if func_name in ALLOWED_FUNCTIONS: + continue + + # Walk the function body looking for requests.get( or requests.post( + for child in ast.walk(node): + if isinstance(child, ast.Call): + func = child.func + # Match: requests.get(...) or requests.post(...) 
+ if (isinstance(func, ast.Attribute) and + isinstance(func.value, ast.Name) and + func.value.id == 'requests' and + func.attr in ('get', 'post')): + violations.append( + f" {func_name}() at line {child.lineno}: " + f"requests.{func.attr}()" + ) + + assert not violations, ( + "Found direct requests.get/post calls in command functions " + "(should use http_get/http_post):\n" + "\n".join(violations) + ) + + +def test_no_direct_requests_get_via_grep(): + """Grep-style check: no unprotected requests.get/post patterns in vast.py.""" + source = _get_vast_source() + lines = source.split('\n') + + violations = [] + pattern = re.compile(r'(?Bad Gateway" + mock_response.text = "Bad Gateway" + mock_response.json.side_effect = JSONDecodeError("msg", "doc", 0) + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + result = api_call(args, "GET", "/test/") + + assert result is not None + assert "_raw_text" in result + assert result["_raw_text"] == "Bad Gateway" + + +@patch('vast.http_get') +def test_api_call_returns_json_on_valid_response(mock_http_get): + """api_call() returns parsed JSON normally when response is valid.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'{"success": true}' + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + result = api_call(args, "GET", "/test/") + + assert result == {"success": True} + + +@patch('vast.http_get') +def test_api_call_returns_none_for_empty_response(mock_http_get): + """api_call() returns None when response has no content.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.status_code = 204 + mock_response.content = b"" + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() 
+ result = api_call(args, "GET", "/test/") + + assert result is None + + +@patch('vast.http_get') +def test_api_call_no_exception_raised_on_html_response(mock_http_get): + """api_call() does NOT raise an exception when API returns HTML.""" + from vast import api_call, JSONDecodeError + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"502 Bad Gateway" + mock_response.text = "502 Bad Gateway" + mock_response.json.side_effect = JSONDecodeError("Expecting value", "", 0) + mock_response.raise_for_status.return_value = None + mock_http_get.return_value = mock_response + + args = _make_args() + + # This should NOT raise - the whole point of the fix + try: + result = api_call(args, "GET", "/instances/") + except JSONDecodeError: + pytest.fail("api_call() should not raise JSONDecodeError") + + assert "_raw_text" in result + + +@patch('vast.http_post') +def test_api_call_post_handles_json_decode_error(mock_http_post): + """api_call() handles JSONDecodeError for POST requests too.""" + from vast import api_call, JSONDecodeError + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"Internal Server Error" + mock_response.text = "Internal Server Error" + mock_response.json.side_effect = JSONDecodeError("msg", "doc", 0) + mock_response.raise_for_status.return_value = None + mock_http_post.return_value = mock_response + + args = _make_args() + result = api_call(args, "POST", "/instances/", json_body={"test": True}) + + assert result is not None + assert result["_raw_text"] == "Internal Server Error" diff --git a/tests/regression/test_main_error_handling.py b/tests/regression/test_main_error_handling.py new file mode 100644 index 00000000..b5d4187b --- /dev/null +++ b/tests/regression/test_main_error_handling.py @@ -0,0 +1,221 @@ +"""Regression tests for code quality fixes. 
+ +- No Python builtins (id, sum) should be shadowed by local variables +- strip('-') should not be used for prefix removal (use startswith + lstrip) +- Error messages should reference correct field name +- Unused variables should be removed +""" +import re +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestVariableShadowing: + """No Python builtins should be shadowed by local variable assignments.""" + + def test_no_id_shadowing(self): + """Local variable 'id' should not shadow builtin. + + Note: keyword args in argparse.Namespace() like id=value are NOT shadowing. + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find all assignments to bare 'id' (not id= which is keyword arg) + # Pattern: whitespace + id + optional space + = + not = (not ==) + # This should NOT match lines like "id=ask_contract_id," in Namespace() + matches = re.findall(r'^\s+id\s+=[^=]', content, re.MULTILINE) + + assert len(matches) == 0, ( + f"Found {len(matches)} instances of 'id' shadowing builtin. " + f"These should be renamed to domain-specific names like instance_id, " + f"workergroup_id, etc. First few matches: {matches[:5]}" + ) + + def test_no_sum_function_shadowing(self): + """Function 'sum' should not shadow builtin. + + The custom sum function should be renamed to sum_field or similar. + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Check for def sum( function definition + matches = re.findall(r'^def sum\s*\(', content, re.MULTILINE) + + assert len(matches) == 0, ( + f"Found def sum() which shadows Python builtin. " + f"Should be renamed to sum_field() or similar." 
+ ) + + def test_sum_field_function_exists(self): + """The renamed sum_field function should exist.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Check for def sum_field( function definition + assert 'def sum_field(' in content, ( + "Expected sum_field() function to exist after renaming from sum()" + ) + + def test_domain_specific_id_names_used(self): + """Verify domain-specific id names are used instead of bare 'id'.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + expected_patterns = [ + 'workergroup_id = args.id', + 'endpoint_id = args.id', + 'volume_id = args.id', + ] + + for pattern in expected_patterns: + assert pattern in content, ( + f"Expected domain-specific name pattern '{pattern}' not found. " + f"Bare 'id' may not have been renamed properly." + ) + + +class TestStringMethodFixes: + """strip('-') should not be used for prefix removal.""" + + def test_no_strip_dash_for_direction(self): + """strip('-') removes from both ends - should use startswith + lstrip.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # .strip("-") or .strip('-') should not appear for direction parsing + matches = re.findall(r'name\.strip\(["\'][-+]["\']', content) + + assert len(matches) == 0, ( + f"Found {len(matches)} instances of name.strip('-') or name.strip('+'). 
" + f"These should use startswith() + lstrip() instead: {matches}" + ) + + def test_direction_parsing_uses_startswith(self): + """Sort direction parsing should use startswith, not strip comparison.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should find startswith("-") for direction checks + has_startswith_minus = 'name.startswith("-")' in content + has_startswith_plus = 'name.startswith("+")' in content + + assert has_startswith_minus, ( + "Expected name.startswith('-') for descending sort direction parsing" + ) + assert has_startswith_plus, ( + "Expected name.startswith('+') for ascending sort direction parsing" + ) + + def test_direction_parsing_uses_lstrip(self): + """After detecting prefix, should use lstrip to remove it.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should find lstrip("-") and lstrip("+") + has_lstrip_minus = 'name.lstrip("-")' in content + has_lstrip_plus = 'name.lstrip("+")' in content + + assert has_lstrip_minus, ( + "Expected name.lstrip('-') for removing descending prefix" + ) + assert has_lstrip_plus, ( + "Expected name.lstrip('+') for removing ascending prefix" + ) + + def test_elif_structure_for_direction_parsing(self): + """Direction parsing should use elif, not two separate if statements. + + Using two separate if statements means a field like '-score' would first + match startswith('-') then incorrectly also check startswith('+'). + """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Look for the correct pattern: if startswith("-") ... elif startswith("+") + # Not: if startswith("-") ... 
if startswith("+") + pattern = r'if name\.startswith\("-"\):.*?elif name\.startswith\("\+"\):' + matches = re.findall(pattern, content, re.DOTALL) + + assert len(matches) >= 4, ( + f"Expected at least 4 occurrences of correct if/elif pattern for " + f"direction parsing (in search__offers, search__instances, " + f"search__volumes, search__network_volumes), found {len(matches)}" + ) + + +class TestErrorMessages: + """Error messages should reference correct field.""" + + def test_start_date_error_message_correct(self): + """Error for start_date should say 'start date', not 'end date'.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should not have "start date" error saying "Ignoring end date" + bad_pattern = re.search( + r'Invalid start date.*Ignoring end date', + content, + re.IGNORECASE + ) + + assert bad_pattern is None, ( + "Found misleading error message - start_date error mentions 'Ignoring end date'. " + "Should say 'Ignoring start date' instead." + ) + + def test_start_date_error_says_start_date(self): + """Start date errors should reference start date.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Should find "Invalid start date" + "Ignoring start date" pattern + correct_pattern = re.search( + r'Invalid start date.*Ignoring start date', + content, + re.IGNORECASE + ) + + assert correct_pattern is not None, ( + "Expected start_date error message to say 'Ignoring start date'" + ) + + +class TestUnusedVariables: + """Unused variables should be removed.""" + + def test_no_unused_date_txt_in_show_earnings(self): + """In show__earnings, date_txt variables should not be assigned if unused. + + Note: date_txt variables ARE used in invoice functions, just not in show__earnings. 
+ """ + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find the show__earnings function + show_earnings_match = re.search( + r'def show__earnings\(args\):.*?(?=\ndef |$)', + content, + re.DOTALL + ) + + assert show_earnings_match is not None, "Could not find show__earnings function" + + show_earnings_code = show_earnings_match.group(0) + + # Check that end_date_txt and start_date_txt are not assigned in this function + has_end_date_txt = 'end_date_txt' in show_earnings_code + has_start_date_txt = 'start_date_txt' in show_earnings_code + + assert not has_end_date_txt, ( + "Found unused end_date_txt assignment in show__earnings. " + "This variable is assigned but never used in this function." + ) + assert not has_start_date_txt, ( + "Found unused start_date_txt assignment in show__earnings. " + "This variable is assigned but never used in this function." + ) diff --git a/tests/regression/test_main_raw_output.py b/tests/regression/test_main_raw_output.py new file mode 100644 index 00000000..7e9df44f --- /dev/null +++ b/tests/regression/test_main_raw_output.py @@ -0,0 +1,199 @@ +"""Regression tests for API request and raw output fixes.""" +import re +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestApiKeyHandling: + """API key should be in headers, not JSON body.""" + + def test_no_api_key_in_json_blob(self): + """api_key should not appear in json_blob assignments.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Look for json_blob containing api_key (but not in comments) + lines = content.split('\n') + violations = [] + for i, line in enumerate(lines, 1): + # Skip comments + if line.strip().startswith('#'): + continue + # Check for api_key in json dict literal on same line as json_blob + if 'json_blob' in line and 'api_key' in line and '=' in line: + 
violations.append(f"Line {i}: {line.strip()}") + + assert len(violations) == 0, f"Found api_key in json_blob: {violations}" + + def test_get_endpt_logs_no_api_key_in_body(self): + """get__endpt_logs should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find function and check its body + pattern = r'def get__endpt_logs\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "get__endpt_logs function not found" + + func_body = match.group(0) + # Should not have api_key in any dict literal + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"get__endpt_logs has api_key in json_blob: {json_lines}" + + def test_get_wrkgrp_logs_no_api_key_in_body(self): + """get__wrkgrp_logs should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def get__wrkgrp_logs\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "get__wrkgrp_logs function not found" + + func_body = match.group(0) + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"get__wrkgrp_logs has api_key in json_blob: {json_lines}" + + def test_show_workergroups_no_api_key_in_body(self): + """show__workergroups should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__workergroups\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__workergroups function not found" + + func_body = match.group(0) + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 
'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"show__workergroups has api_key in json_blob: {json_lines}" + + def test_show_endpoints_no_api_key_in_body(self): + """show__endpoints should not have api_key in JSON.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__endpoints\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__endpoints function not found" + + func_body = match.group(0) + json_lines = [l for l in func_body.split('\n') + if 'json_blob' in l and 'api_key' in l and '=' in l + and not l.strip().startswith('#')] + assert len(json_lines) == 0, f"show__endpoints has api_key in json_blob: {json_lines}" + + +class TestSafeIteration: + """next() calls should have default to prevent StopIteration.""" + + def test_next_calls_have_default(self): + """All next() calls with generators should have a default parameter.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find next() calls without default + # Pattern: next(something for something) without a comma before closing + lines = content.split('\n') + risky_calls = [] + for i, line in enumerate(lines, 1): + if 'next(' in line and not line.strip().startswith('#'): + # Simple heuristic: if line has next( with 'for' inside but no comma + # This catches: next(x for x in y) but not next((x for x in y), None) + match = re.search(r'next\(\s*[^,)]+\s+for\s+[^,)]+\)', line) + if match: + risky_calls.append(f"Line {i}: {line.strip()}") + + assert len(risky_calls) == 0, f"Found next() without default: {risky_calls}" + + def test_show_clusters_next_has_default(self): + """show__clusters manager_node lookup should have default.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def 
show__clusters\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__clusters function not found" + + func_body = match.group(0) + + # Should have next(..., None) pattern - may be multiline with nested parens + assert 'next(' in func_body, "show__clusters should use next()" + # Check for manager_node = next(..., None) with multiline and nested parens + # The pattern has: next(\n(generator),\nNone\n) + has_none_default = re.search(r'manager_node\s*=\s*next\s*\(.*?,\s*None\s*\)', func_body, re.DOTALL) + assert has_none_default, "show__clusters next() should have None default" + + def test_show_clusters_handles_missing_manager(self): + """show__clusters should handle case when no manager node exists.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + pattern = r'def show__clusters\(.*?(?=\ndef |\Z)' + match = re.search(pattern, content, re.DOTALL) + assert match, "show__clusters function not found" + + func_body = match.group(0) + + # Should check for None manager_node + assert 'manager_node is None' in func_body or 'if manager_node is None' in func_body, \ + "show__clusters should check for None manager_node" + + +class TestTransferCredit: + """--transfer_credit should be implemented or removed from docs.""" + + def test_transfer_credit_not_in_create_team_epilog(self): + """create team epilog should not mention --transfer_credit.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find create__team function and its decorator + pattern = r'@parser\.command\([^)]*argument\([^)]*team_name[^)]*\)[^)]*\)[^@]*def create__team' + match = re.search(pattern, content, re.DOTALL) + if match: + decorator_and_func = match.group(0) + # Should not mention --transfer_credit as a flag + assert '--transfer_credit' not in decorator_and_func, \ + "create team should not document --transfer_credit as a flag" + + def 
test_transfer_credit_consistency(self): + """If --transfer_credit is mentioned, it should be documented correctly.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # Find create__team decorator/epilog section + pattern = r'@parser\.command\(\s*argument\("--team_name".*?def create__team' + match = re.search(pattern, content, re.DOTALL) + + if match: + section = match.group(0) + # If transfer_credit is mentioned at all, it should NOT be as --transfer_credit flag + if 'transfer_credit' in section.lower(): + # Should be pointing to the separate command, not documenting a flag + assert 'vastai transfer credit' in section or '--transfer_credit' not in section, \ + "transfer_credit should point to 'vastai transfer credit' command, not a flag" + + def test_transfer_credit_command_exists(self): + """The transfer__credit command should exist as separate command.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + content = vast_path.read_text(encoding='utf-8') + + # transfer__credit should exist as its own command + assert 'def transfer__credit' in content, \ + "transfer__credit should exist as a separate command" + + # It should have proper argument decorators - multiline decorator + # Look for recipient and amount in argument() calls before def transfer__credit + pattern = r'@parser\.command\(.*?def transfer__credit' + match = re.search(pattern, content, re.DOTALL) + assert match, "transfer__credit should have @parser.command decorator" + decorator_section = match.group(0) + assert 'recipient' in decorator_section, "transfer__credit should have recipient argument" + assert 'amount' in decorator_section, "transfer__credit should have amount argument" diff --git a/tests/regression/test_mutable_defaults.py b/tests/regression/test_mutable_defaults.py new file mode 100644 index 00000000..777c2b72 --- /dev/null +++ b/tests/regression/test_mutable_defaults.py @@ -0,0 +1,54 @@ +"""Mutable default 
arguments in http_put, http_post, http_del. + +The bug: Using json={} as a default argument means all calls share the same +dict object. If any caller mutates the dict, subsequent calls see the mutation. + +The fix: Use json=None and initialize to {} at runtime. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +from unittest.mock import MagicMock, patch +import argparse + + +def _make_args(): + return argparse.Namespace(retry=1, curl=False) + + +def _mock_http_request(verb, args, req_url, headers, json, **kwargs): + """Capture the json arg and return a mock response.""" + r = MagicMock() + r.status_code = 200 + # Store reference to the json dict that was passed + r._captured_json = json + return r + + +@patch('vast.http_request', side_effect=_mock_http_request) +def test_http_put_no_shared_default(mock_req): + from vast import http_put + args = _make_args() + r1 = http_put(args, "http://test1", headers=None) + r2 = http_put(args, "http://test2", headers=None) + # Each call must get its own dict, not share the mutable default + assert r1._captured_json is not r2._captured_json + + +@patch('vast.http_request', side_effect=_mock_http_request) +def test_http_post_no_shared_default(mock_req): + from vast import http_post + args = _make_args() + r1 = http_post(args, "http://test1", headers=None) + r2 = http_post(args, "http://test2", headers=None) + assert r1._captured_json is not r2._captured_json + + +@patch('vast.http_request', side_effect=_mock_http_request) +def test_http_del_no_shared_default(mock_req): + from vast import http_del + args = _make_args() + r1 = http_del(args, "http://test1", headers=None) + r2 = http_del(args, "http://test2", headers=None) + assert r1._captured_json is not r2._captured_json diff --git a/tests/regression/test_namespace_typo.py b/tests/regression/test_namespace_typo.py new file mode 100644 index 00000000..2aa8a79b --- /dev/null +++ b/tests/regression/test_namespace_typo.py @@ -0,0 
+1,23 @@ +"""Typo 'debbuging' in Namespace construction. + +The bug: destroy_args = argparse.Namespace(..., debbuging=args.debugging, ...) +creates an attribute 'debbuging' instead of 'debugging'. Any code accessing +destroy_args.debugging gets AttributeError. + +The fix: Change 'debbuging' to 'debugging'. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + + +def test_no_debbuging_typo_in_source(): + """Verify the typo 'debbuging' does not appear in vast.py source.""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r', encoding='utf-8', errors='replace') as f: + source = f.read() + + assert 'debbuging' not in source, ( + "Found 'debbuging' typo in vast.py. " + "Should be 'debugging' in the Namespace construction." + ) diff --git a/tests/regression/test_parse_query.py b/tests/regression/test_parse_query.py new file mode 100644 index 00000000..9d6bfaac --- /dev/null +++ b/tests/regression/test_parse_query.py @@ -0,0 +1,59 @@ +"""parse_query() field alias bug -- v references dict after pop. + +The bug: `v = res.setdefault(field, {})` gets a reference to a dict at the +original field name. Then `res.pop(field)` removes it from res. Writing to v +modifies an orphaned dict that's no longer in res. + +Fields affected: cuda_vers->cuda_max_good, dph->dph_total, dlperf_usd->dlperf_per_dphtotal + +The fix: Resolve alias before calling setdefault. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + + +def test_field_alias_cuda_vers(): + """parse_query correctly aliases cuda_vers to cuda_max_good.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers >= 12.0", {}, offers_fields, offers_alias) + # The result should have cuda_max_good, NOT cuda_vers + assert 'cuda_max_good' in result, ( + f"Expected 'cuda_max_good' in result, got keys: {list(result.keys())}. 
" + f"Field alias not applied correctly." + ) + assert 'cuda_vers' not in result, "Old field name 'cuda_vers' should not be in result" + assert 'gte' in result['cuda_max_good'], ( + f"Expected 'gte' operator in cuda_max_good, got: {result['cuda_max_good']}" + ) + + +def test_field_alias_dph(): + """parse_query correctly aliases dph to dph_total.""" + from vast import parse_query, offers_fields, offers_alias, offers_mult + + result = parse_query("dph <= 1.5", {}, offers_fields, offers_alias, offers_mult) + assert 'dph_total' in result, f"Expected 'dph_total', got keys: {list(result.keys())}" + assert 'dph' not in result + assert 'lte' in result['dph_total'] + + +def test_field_alias_value_preserved(): + """After alias resolution, the operator value is correctly stored.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers >= 12.0", {}, offers_fields, offers_alias) + assert result['cuda_max_good']['gte'] == '12.0', ( + f"Expected value '12.0', got: {result['cuda_max_good'].get('gte')}" + ) + + +def test_non_aliased_field_unaffected(): + """Fields without aliases work normally.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("num_gpus >= 2", {}, offers_fields, offers_alias) + assert 'num_gpus' in result + assert 'gte' in result['num_gpus'] + assert result['num_gpus']['gte'] == '2' diff --git a/tests/regression/test_raw_completeness.py b/tests/regression/test_raw_completeness.py new file mode 100644 index 00000000..a798c9dd --- /dev/null +++ b/tests/regression/test_raw_completeness.py @@ -0,0 +1,176 @@ +""" +Regression Tests: Raw Mode Completeness + +Verifies that command functions have consistent --raw handling. 
+""" +import re +import sys +from pathlib import Path + +# Add vast-cli to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class TestRawModeCompleteness: + """Tests that command functions have raw mode handling.""" + + def setup_method(self): + """Load vast.py source code once per test.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + self.source = vast_path.read_text(encoding="utf-8") + + def test_minimum_raw_handlers(self): + """Verify at least 90 raw mode handlers exist.""" + raw_count = len(re.findall(r"if args\.raw:", self.source)) + assert raw_count >= 90, f"Expected at least 90 raw handlers, found {raw_count}" + + def test_command_functions_have_raw_handling(self): + """Check that common command functions have if args.raw: patterns.""" + # Key command functions that should have raw handling + expected_functions = [ + "attach__ssh", + "cancel__copy", + "cancel__sync", + "change__bid", + "create__api_key", + "create__env_var", + "create__ssh_key", + "create__workergroup", + "create__endpoint", + "create__subaccount", + "create__team", + "create__team_role", + "create__template", + "delete__api_key", + "delete__ssh_key", + "delete__scheduled_job", + "delete__workergroup", + "delete__endpoint", + "delete__env_var", + "delete__template", + "destroy__team", + "detach__ssh", + "invite__member", + "label__instance", + "prepay__instance", + "reboot__instance", + "recycle__instance", + "remove__member", + "remove__team_role", + "reports", + "reset__api_key", + "transfer__credit", + ] + + for func_name in expected_functions: + # Find the function definition + func_pattern = rf"^def {func_name}\(args" + match = re.search(func_pattern, self.source, re.MULTILINE) + assert match, f"Function {func_name} not found" + + # Get function body (rough approximation - from def to next def or EOF) + start = match.start() + next_def = re.search(r"^def \w+\(", self.source[start + 10:], re.MULTILINE) + end = start + 10 + 
next_def.start() if next_def else len(self.source) + func_body = self.source[start:end] + + # Check for raw handling + has_raw = "if args.raw:" in func_body + assert has_raw, f"Function {func_name} missing 'if args.raw:' handling" + + def test_no_orphan_print_without_raw_check(self): + """ + Verify that functions returning JSON data have raw checks before prints. + + This is a sampling test - check a few critical functions. + """ + # Functions that should have raw handling before their main output + critical_patterns = [ + # (function_name, expected_pattern_after_raw_check) + ("create__team", r"if args\.raw:.*?return.*?print\(result\)"), + ("delete__api_key", r"if args\.raw:.*?return.*?print\(result\)"), + ] + + for func_name, _ in critical_patterns: + func_pattern = rf"^def {func_name}\(args" + match = re.search(func_pattern, self.source, re.MULTILINE) + assert match, f"Function {func_name} not found" + + start = match.start() + next_def = re.search(r"^def \w+\(", self.source[start + 10:], re.MULTILINE) + end = start + 10 + next_def.start() if next_def else len(self.source) + func_body = self.source[start:end] + + # Verify raw check exists + assert "if args.raw:" in func_body, f"{func_name} should have raw check" + + +class TestRawModeReturnsData: + """Tests that raw mode handlers return data, not None.""" + + def setup_method(self): + """Load vast.py source code once per test.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + self.source = vast_path.read_text(encoding="utf-8") + + def test_raw_handlers_have_return_statements(self): + """Verify raw mode handlers include return statements.""" + # Find all if args.raw: blocks + raw_pattern = r"if args\.raw:\s*\n\s+return" + matches = re.findall(raw_pattern, self.source) + # Should have many return statements following raw checks + assert len(matches) >= 80, f"Expected 80+ 'if args.raw: return' patterns, found {len(matches)}" + + def test_no_empty_raw_blocks(self): + """Verify no raw blocks that just 
pass or do nothing.""" + # Pattern for raw checks that just pass + empty_raw_pattern = r"if args\.raw:\s*\n\s+pass\s*\n" + matches = re.findall(empty_raw_pattern, self.source) + assert len(matches) == 0, f"Found {len(matches)} empty 'if args.raw: pass' blocks" + + +class TestRawModeConsistency: + """Tests for consistent raw mode patterns across the codebase.""" + + def setup_method(self): + """Load vast.py source code once per test.""" + vast_path = Path(__file__).parent.parent.parent / "vast.py" + self.source = vast_path.read_text(encoding="utf-8") + + def test_consistent_raw_return_pattern(self): + """ + Verify raw mode uses consistent return patterns. + + Expected patterns: + - if args.raw: return rj + - if args.raw: return result + - if args.raw: return rows + - if args.raw: return data + """ + # Find all raw return patterns + raw_return_pattern = r"if args\.raw:\s*\n\s+return\s+(\w+)" + matches = re.findall(raw_return_pattern, self.source) + + # Common return variable names + valid_names = {"rj", "result", "rows", "data", "processed", "user_blob", + "response_data", "instances", "volumes", "machines", "offers"} + + for var_name in matches: + # Allow any reasonable variable name (not just the common ones) + # This is a sanity check - variables should be short identifiers + assert len(var_name) < 30, f"Suspicious return variable: {var_name}" + + def test_output_result_handles_raw(self): + """Verify output_result function exists and handles raw mode.""" + # Check that output_result is defined + assert "def output_result(" in self.source, "output_result function not found" + + # Check that output_result checks args.raw + output_result_match = re.search( + r"def output_result\(.*?\n(.*?)(?=^def |\Z)", + self.source, + re.MULTILINE | re.DOTALL + ) + assert output_result_match, "Could not extract output_result function body" + func_body = output_result_match.group(1) + assert "args.raw" in func_body, "output_result should check args.raw" diff --git 
a/tests/regression/test_raw_errors.py b/tests/regression/test_raw_errors.py new file mode 100644 index 00000000..b30b54b7 --- /dev/null +++ b/tests/regression/test_raw_errors.py @@ -0,0 +1,188 @@ +"""Error messages in --raw mode must be valid JSON. + +The bug: When --raw mode is active and an HTTPError or ValueError occurs, +the error handler prints plain text (e.g., "failed with error 500: ..."). +Scripts and automation tools parsing JSON output get broken by mixed +text/JSON output. + +The fix: Check args.raw in the HTTPError and ValueError exception handlers +in main(), and output a JSON object with error/status_code/msg fields. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import json +import argparse +import pytest +from unittest import mock +from io import StringIO + +import requests + + +class TestHTTPErrorRawMode: + """Verify HTTPError in --raw mode produces valid JSON.""" + + def _make_mock_response(self, status_code, json_body=None): + """Create a mock response for HTTPError.""" + resp = mock.MagicMock() + resp.status_code = status_code + if json_body is not None: + resp.json.return_value = json_body + else: + resp.json.side_effect = json.JSONDecodeError("No JSON", "", 0) + return resp + + def _run_main_with_error(self, args): + """Run vast.main() with proper mocking, return captured stdout.""" + import vast + + captured = StringIO() + with mock.patch.object(vast, 'ARGS', args): + with mock.patch('vast.parser') as mock_parser: + mock_parser.parse_args.return_value = args + mock_parser.add_argument = mock.MagicMock() + mock_parser.parser = mock.MagicMock() + with mock.patch('vast.should_check_for_update', False): + with mock.patch('vast.TABCOMPLETE', False): + with mock.patch('vast.api_key_guard', 'GUARD'): + with mock.patch('sys.stdout', captured): + try: + vast.main() + except SystemExit: + pass + return captured.getvalue().strip() + + def test_http_error_raw_mode_produces_json(self): + 
"""HTTPError with --raw should output valid JSON with error/status_code/msg.""" + mock_resp = self._make_mock_response(500, {"msg": "Internal server error"}) + http_error = requests.exceptions.HTTPError(response=mock_resp) + + args = argparse.Namespace( + raw=True, func=mock.MagicMock(side_effect=http_error), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + parsed = json.loads(output) + assert parsed["error"] is True + assert parsed["status_code"] == 500 + assert parsed["msg"] == "Internal server error" + + def test_http_error_raw_mode_401_produces_json(self): + """HTTPError 401 with --raw should output JSON with login message.""" + mock_resp = self._make_mock_response(401, json_body=None) + http_error = requests.exceptions.HTTPError(response=mock_resp) + + args = argparse.Namespace( + raw=True, func=mock.MagicMock(side_effect=http_error), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + parsed = json.loads(output) + assert parsed["error"] is True + assert parsed["status_code"] == 401 + assert "log in" in parsed["msg"].lower() or "sign up" in parsed["msg"].lower() + + def test_http_error_non_raw_mode_produces_text(self): + """HTTPError without --raw should output human-readable text.""" + mock_resp = self._make_mock_response(500, {"msg": "Server error"}) + http_error = requests.exceptions.HTTPError(response=mock_resp) + + args = argparse.Namespace( + raw=False, func=mock.MagicMock(side_effect=http_error), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + # Non-raw should NOT be valid JSON with error key + assert "failed with error" in output.lower() + + +class TestValueErrorRawMode: + """Verify ValueError in --raw mode produces valid JSON.""" 
+ + def _run_main_with_error(self, args): + """Run vast.main() with proper mocking, return captured stdout.""" + import vast + + captured = StringIO() + with mock.patch.object(vast, 'ARGS', args): + with mock.patch('vast.parser') as mock_parser: + mock_parser.parse_args.return_value = args + mock_parser.add_argument = mock.MagicMock() + mock_parser.parser = mock.MagicMock() + with mock.patch('vast.should_check_for_update', False): + with mock.patch('vast.TABCOMPLETE', False): + with mock.patch('vast.api_key_guard', 'GUARD'): + with mock.patch('sys.stdout', captured): + try: + vast.main() + except SystemExit: + pass + return captured.getvalue().strip() + + def test_value_error_raw_mode_produces_json(self): + """ValueError with --raw should output valid JSON with error/msg.""" + args = argparse.Namespace( + raw=True, func=mock.MagicMock(side_effect=ValueError("bad value")), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + parsed = json.loads(output) + assert parsed["error"] is True + assert parsed["msg"] == "bad value" + + def test_value_error_non_raw_mode_produces_text(self): + """ValueError without --raw should print the error message as text.""" + args = argparse.Namespace( + raw=False, func=mock.MagicMock(side_effect=ValueError("bad value")), + api_key="test", url="https://test.com", explain=False, + retry=3, full=False, curl=False, no_color=False, + ) + + output = self._run_main_with_error(args) + assert output == "bad value" + + +class TestRawErrorHandlerLintChecks: + """Lint-style tests to verify raw error handling exists in main().""" + + def test_httperror_handler_checks_args_raw(self): + """The HTTPError handler in main() must check args.raw.""" + VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Find the HTTPError handler section + assert 
'HTTPError' in content, "HTTPError handler must exist" + # Find args.raw in the context of error handling + import re + # Look for args.raw near HTTPError handling + http_error_section = content[content.index('HTTPError'):] + # Limit to the next except or end of function + next_section = http_error_section[:http_error_section.index('except ValueError')] + assert 'args.raw' in next_section, ( + "HTTPError handler must check args.raw for JSON output" + ) + + def test_valueerror_handler_checks_args_raw(self): + """The ValueError handler in main() must check args.raw.""" + VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Find the ValueError handler in main() - the one near end of file + # Look for the last 'except ValueError' which is in main() + last_ve_idx = content.rindex('except ValueError') + ve_section = content[last_ve_idx:last_ve_idx + 300] + assert 'args.raw' in ve_section, ( + "ValueError handler in main() must check args.raw for JSON output" + ) diff --git a/tests/regression/test_retry.py b/tests/regression/test_retry.py new file mode 100644 index 00000000..ab1822a5 --- /dev/null +++ b/tests/regression/test_retry.py @@ -0,0 +1,233 @@ +"""Incomplete retry logic -- only retries on HTTP 429. + +The bug: http_request() only retries when the server returns 429 (rate limit). +Transient 5xx errors (502, 503, 504) and connection failures cause immediate +command failure instead of being retried. + +The fix: Expand retry to cover {429, 502, 503, 504} status codes and split +exception handling so ConnectionError/Timeout are retried while other +RequestException subclasses are raised immediately. 
+""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +from unittest.mock import MagicMock, patch +import pytest +import requests.exceptions + + +def _make_args(retry=3): + return argparse.Namespace(retry=retry, curl=False) + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_502(mock_session_cls, mock_sleep): + """http_request retries on 502 Bad Gateway and recovers.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_502 = MagicMock() + response_502.status_code = 502 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_502, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_503(mock_session_cls, mock_sleep): + """http_request retries on 503 Service Unavailable.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_503 = MagicMock() + response_503.status_code = 503 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_503, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_504(mock_session_cls, mock_sleep): + """http_request retries on 504 Gateway Timeout.""" + from vast import http_request + + mock_session 
= MagicMock() + mock_session_cls.return_value = mock_session + + response_504 = MagicMock() + response_504.status_code = 504 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_504, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_429_still_works(mock_session_cls, mock_sleep): + """Original 429 retry behavior is preserved.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_429 = MagicMock() + response_429.status_code = 429 + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [response_429, response_200] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + # Should have slept once (after the 429) + assert mock_sleep.call_count == 1 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_connection_error(mock_session_cls, mock_sleep): + """http_request retries on ConnectionError and recovers.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [ + requests.exceptions.ConnectionError("connection refused"), + response_200, + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 
'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retry_on_timeout_exception(mock_session_cls, mock_sleep): + """http_request retries on Timeout exception and recovers.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + response_200 = MagicMock() + response_200.status_code = 200 + + mock_session.send.side_effect = [ + requests.exceptions.Timeout("read timed out"), + response_200, + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_no_retry_on_non_retryable_exception(mock_session_cls, mock_sleep): + """Non-retryable RequestException subclasses are raised immediately.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + # InvalidURL is a non-retryable error + mock_session.send.side_effect = requests.exceptions.InvalidURL("bad url") + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + with pytest.raises(requests.exceptions.InvalidURL): + http_request('GET', args, 'http://example.com/test') + + # Should NOT have retried -- only 1 call + assert mock_session.send.call_count == 1 + # Should NOT have slept + assert mock_sleep.call_count == 0 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_no_retry_on_non_retryable_status(mock_session_cls, mock_sleep): + """Non-retryable status codes (e.g., 400, 404, 500) break out immediately.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = 
mock_session + + response_400 = MagicMock() + response_400.status_code = 400 + + mock_session.send.return_value = response_400 + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + # Should return on first call without retrying + assert result.status_code == 400 + assert mock_session.send.call_count == 1 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_retryable_status_codes_constant(mock_session_cls, mock_sleep): + """RETRYABLE_STATUS_CODES contains exactly {429, 502, 503, 504}.""" + from vast import RETRYABLE_STATUS_CODES + + assert RETRYABLE_STATUS_CODES == {429, 502, 503, 504} diff --git a/tests/regression/test_return_response.py b/tests/regression/test_return_response.py new file mode 100644 index 00000000..c97d9ca1 --- /dev/null +++ b/tests/regression/test_return_response.py @@ -0,0 +1,58 @@ +"""22 functions return Response object instead of parsed JSON. + +The bug: Functions using http_* directly return `r` (Response object) in raw +mode. Response objects are not JSON-serializable, causing json.dumps to fail. +The bare except: in main() masks this by calling res.json() as a fallback. + +The fix: Change `return r` to `return r.json()` in all 22 functions. +Also: Remove bare except in main() since all returns are now serializable. 
+""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re + + +def test_no_bare_return_r_in_functions(): + """No function (except http_request) returns bare `r` (Response object).""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r', encoding='utf-8', errors='replace') as f: + lines = f.readlines() + + # Find all 'return r' lines outside http_request + in_http_request = False + violations = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + if stripped.startswith('def http_request('): + in_http_request = True + elif stripped.startswith('def ') and in_http_request: + in_http_request = False + + if in_http_request: + continue + + # Match 'return r' but not 'return rows', 'return r.json()', etc. + if re.match(r'\s+return r\s*$', line): + violations.append(f"Line {i}: {stripped}") + + assert not violations, ( + f"Found {len(violations)} function(s) still returning bare Response object:\n" + + "\n".join(violations) + ) + + +def test_no_bare_except_in_main(): + """main() should not have a bare except: clause for raw output.""" + vast_path = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + with open(vast_path, 'r', encoding='utf-8', errors='replace') as f: + source = f.read() + + # Check that res.json() fallback is gone from main's raw handler + # The old pattern was: try: json.dumps(res) except: json.dumps(res.json()) + # After fix: just json.dumps(res) with no try/except + assert 'res.json()' not in source, ( + "Found res.json() fallback in main(). " + "All returns are JSON-serializable and the fallback is dead code." + ) diff --git a/tests/regression/test_safe_dict_access.py b/tests/regression/test_safe_dict_access.py new file mode 100644 index 00000000..fc285a5e --- /dev/null +++ b/tests/regression/test_safe_dict_access.py @@ -0,0 +1,237 @@ +"""Safe dict access on API response dicts. 
+ +The bug: 60+ locations accessed API response dicts with rj["key"], +r.json()["key"], or result["key"] which raises KeyError if the API +response format changes or an endpoint returns unexpected data. + +The fix: Convert all API response dict accesses to .get() with +appropriate defaults: + - Boolean checks: rj.get("success") -- None is falsy + - Messages: rj.get("msg", "Unknown error") -- fallback text + - Iterable data: rj.get("offers", []) -- empty list for iteration + - Required fields: rj.get("result_url") with explicit error check +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import argparse +import pytest +from unittest.mock import MagicMock, patch + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +# --------------------------------------------------------------------------- # +# Lint-style tests # +# --------------------------------------------------------------------------- # + +class TestMinimalRawDictAccess: + """Ensure almost no raw rj['key'] access patterns remain on API data.""" + + def test_minimal_rj_bracket_access(self): + """Count rj["..."] patterns -- should be zero after the fix.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + raw_accesses = re.findall(r'rj\["[^"]+"\]', content) + assert len(raw_accesses) == 0, ( + f"Found {len(raw_accesses)} raw rj[\"key\"] accesses; " + f"expected 0. Convert to rj.get('key', default). " + f"Matches: {raw_accesses[:5]}" + ) + + def test_minimal_rj_single_quote_access(self): + """Count rj['...'] patterns -- should be zero after the fix.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + raw_accesses = re.findall(r"rj\['[^']+'\]", content) + assert len(raw_accesses) == 0, ( + f"Found {len(raw_accesses)} raw rj['key'] accesses; " + f"expected 0. Convert to rj.get('key', default). 
" + f"Matches: {raw_accesses[:5]}" + ) + + def test_minimal_r_json_bracket_access(self): + """Count r.json()["..."] patterns (excluding comments).""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + lines = f.readlines() + raw_accesses = [] + for i, line in enumerate(lines, 1): + stripped = line.lstrip() + if stripped.startswith('#'): + continue + matches = re.findall(r'\.json\(\)\["[^"]+"\]', line) + for m in matches: + raw_accesses.append(f"line {i}: {m}") + assert len(raw_accesses) == 0, ( + f"Found {len(raw_accesses)} raw .json()[\"key\"] accesses in " + f"non-comment lines; expected 0. Matches: {raw_accesses[:5]}" + ) + + def test_high_safe_access_count(self): + """Verify a high number of .get() patterns exist.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + rj_gets = len(re.findall(r'rj\.get\(', content)) + json_gets = len(re.findall(r'\.json\(\)\.get\(', content)) + result_gets = len(re.findall(r'result\.get\(', content)) + total = rj_gets + json_gets + result_gets + assert total >= 60, ( + f"Found only {total} safe .get() accesses on API response dicts " + f"(rj.get: {rj_gets}, .json().get: {json_gets}, result.get: {result_gets}); " + f"expected >= 60 after safe dict access conversion." 
+ ) + + +# --------------------------------------------------------------------------- # +# Functional tests: missing "success" key # +# --------------------------------------------------------------------------- # + +class TestMissingSuccessKey: + """Verify functions handle missing 'success' key without KeyError.""" + + @patch('vast.http_put') + def test_prepay_instance_no_success_key(self, mock_put, capsys): + """prepay__instance should not crash if 'success' key is missing.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + # API returns dict WITHOUT 'success' key + mock_response.json.return_value = {"some_other_field": 123} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + id=12345, amount=10.0 + ) + + # Should not raise KeyError + vast.prepay__instance(args) + + captured = capsys.readouterr() + # Since success is missing (falsy), it should print the error branch + assert "Unknown error" in captured.out + + @patch('vast.api_call') + def test_label_instance_no_success_key(self, mock_api_call, capsys): + """label__instance should not crash if 'success' key is missing.""" + import vast + + # api_call returns a dict without 'success' key + mock_api_call.return_value = {"some_other_field": 123} + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + id=12345, label="test-label" + ) + + # Should not raise KeyError + vast.label__instance(args) + + captured = capsys.readouterr() + # Since success is missing (falsy), it should print the error branch + assert "Unknown error" in captured.out + + +# --------------------------------------------------------------------------- # +# Functional tests: missing "msg" key # +# --------------------------------------------------------------------------- # + +class 
TestMissingMsgKey: + """Verify functions handle missing 'msg' key by printing fallback.""" + + @patch('vast.http_put') + def test_prepay_failure_no_msg(self, mock_put, capsys): + """When API returns success=False without msg, print 'Unknown error'.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": False} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + id=999, amount=5.0 + ) + + vast.prepay__instance(args) + + captured = capsys.readouterr() + assert "Unknown error" in captured.out + + @patch('vast.http_post') + def test_create_overlay_no_msg(self, mock_post, capsys): + """create__overlay should print 'Unknown error' when msg key missing.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"status": "ok"} # no "msg" key + mock_post.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False, + name="test-overlay", cluster_id=123 + ) + + vast.create__overlay(args) + + captured = capsys.readouterr() + assert "Unknown error" in captured.out + + +# --------------------------------------------------------------------------- # +# Functional tests: missing data extraction keys # +# --------------------------------------------------------------------------- # + +class TestMissingDataKeys: + """Verify functions handle missing data keys gracefully.""" + + @patch('vast.http_get') + def test_show_instances_missing_instances_key(self, mock_get, capsys): + """show__instances should handle missing 'instances' key without crash.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = 
MagicMock() + mock_response.json.return_value = {"success": True} # no "instances" + mock_get.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=True, explain=False, + quiet=False + ) + + # Should return empty list, not KeyError + result = vast.show__instances(args) + assert result == [] + + @patch('vast.api_call') + def test_show_volumes_missing_volumes_key(self, mock_api_call, capsys): + """show__volumes should handle missing 'volumes' key without crash.""" + import vast + + # api_call returns a dict without 'volumes' key + mock_api_call.return_value = {"success": True} + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=True, explain=False, + quiet=False, type="all" + ) + + # Should return empty list, not KeyError + result = vast.show__volumes(args) + assert result == [] or result is None diff --git a/tests/regression/test_shell_injection.py b/tests/regression/test_shell_injection.py new file mode 100644 index 00000000..40556aa0 --- /dev/null +++ b/tests/regression/test_shell_injection.py @@ -0,0 +1,116 @@ +"""No shell=True subprocess calls in vast.py. + +The bug: subprocess calls with shell=True are vulnerable to command injection +(CWE-78). User-controlled data (paths, instance IDs, addresses) could be +injected into shell commands. Additionally, subprocess.getoutput() implicitly +uses shell=True. + +The fix: Convert all shell=True calls to argument lists. Replace +subprocess.getoutput("echo $HOME") with os.path.expanduser("~"). Convert +get_update_command() to return a list instead of a string. 
+""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestNoShellTrue: + """Lint-style tests verifying no shell=True remains in vast.py.""" + + def test_no_shell_true(self): + """No shell=True subprocess calls should exist in vast.py. + + shell=True enables command injection when user-controlled data + is passed to subprocess calls. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + assert 'shell=True' not in content, ( + "Found shell=True in vast.py. All subprocess calls must use " + "argument lists instead of shell command strings." + ) + + def test_no_subprocess_getoutput(self): + """No subprocess.getoutput() calls should exist in vast.py. + + subprocess.getoutput() implicitly uses shell=True and is + vulnerable to command injection. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + assert 'subprocess.getoutput' not in content, ( + "Found subprocess.getoutput in vast.py. Use os.path.expanduser " + "or subprocess.run with argument lists instead." 
+ ) + + +class TestGetUpdateCommand: + """Verify get_update_command() returns a list, not a string.""" + + def test_returns_list_for_pip(self): + """get_update_command() should return a list when is_pip_package.""" + import vast + from unittest.mock import patch + + with patch.object(vast, 'is_pip_package', return_value=True): + result = vast.get_update_command("1.2.3") + assert isinstance(result, list), ( + f"get_update_command() returned {type(result).__name__}, expected list" + ) + assert all(isinstance(item, str) for item in result), ( + "All elements of the command list must be strings" + ) + assert "vastai==1.2.3" in result, ( + "Command list should include version-pinned package name" + ) + + def test_returns_list_for_git(self): + """get_update_command() should return a list when not pip.""" + import vast + from unittest.mock import patch + + with patch.object(vast, 'is_pip_package', return_value=False): + result = vast.get_update_command("1.2.3") + assert isinstance(result, list), ( + f"get_update_command() returned {type(result).__name__}, expected list" + ) + assert all(isinstance(item, str) for item in result), ( + "All elements of the command list must be strings" + ) + assert "git" in result, "Git command list should contain 'git'" + + def test_pip_command_has_no_shell_operators(self): + """The pip command list should not contain shell operators.""" + import vast + from unittest.mock import patch + + # Shell operators that indicate command chaining/injection + shell_operators = ['&&', '||', '|', ';', '>', '<', '$(', '`'] + with patch.object(vast, 'is_pip_package', return_value=True): + result = vast.get_update_command("1.2.3") + combined = " ".join(result) + for op in shell_operators: + assert op not in combined, ( + f"Shell operator {op!r} found in pip command: {combined!r}" + ) + + def test_git_command_has_no_shell_operators(self): + """The git command list should not contain && or | operators.""" + import vast + from unittest.mock import patch + + 
with patch.object(vast, 'is_pip_package', return_value=False): + result = vast.get_update_command("1.2.3") + combined = " ".join(result) + assert "&&" not in combined, ( + "Git command should not contain '&&' -- use separate subprocess calls" + ) + assert "|" not in combined, ( + "Git command should not contain pipe operators" + ) diff --git a/tests/regression/test_show_instances.py b/tests/regression/test_show_instances.py new file mode 100644 index 00000000..fad8df97 --- /dev/null +++ b/tests/regression/test_show_instances.py @@ -0,0 +1,90 @@ +"""show__instances() loop rebinds local variable without updating list. + +The bug: `for row in rows: row = {...}` rebinds the local `row` variable to a +new dict, but the original dict in `rows` is unchanged. The stripped strings +and computed duration are lost. + +The fix: Build a new list and reassign rows. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +import time +from unittest.mock import MagicMock, patch + + +def test_rows_are_modified_after_loop(): + """After show__instances processes rows, the returned rows have modified data.""" + from vast import show__instances + + mock_instances = [ + { + "id": 12345, + "start_date": time.time() - 3600, # started 1 hour ago + "extra_env": [["KEY1", "val1"], ["KEY2", "val2"]], + "status": "running", + "name": " test ", # has leading/trailing spaces + } + ] + + mock_response = MagicMock() + mock_response.json.return_value = {"instances": mock_instances} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, full=False, + ) + + with patch('vast.apiurl', return_value="https://console.vast.ai/api/v0/instances"), \ + patch('vast.http_get', return_value=mock_response): + result = show__instances(args, extra={}) + + # In raw mode, show__instances 
returns rows + assert result is not None, "show__instances returned None in raw mode" + assert len(result) > 0, "show__instances returned empty list" + row = result[0] + # The row should have 'duration' field computed from the loop + assert 'duration' in row, "Row missing 'duration' -- loop rebinding bug still present" + assert row['duration'] > 0, f"Duration should be positive, got {row['duration']}" + # extra_env should be converted from list-of-pairs to dict + assert isinstance(row['extra_env'], dict), "extra_env not converted to dict" + assert row['extra_env'].get('KEY1') == 'val1' + + +def test_rows_stripped_strings_preserved(): + """Verify strip_strings is applied and preserved in the returned rows.""" + from vast import show__instances + + mock_instances = [ + { + "id": 99, + "start_date": time.time() - 100, + "extra_env": [], + "status": " running ", + } + ] + + mock_response = MagicMock() + mock_response.json.return_value = {"instances": mock_instances} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, full=False, + ) + + with patch('vast.apiurl', return_value="https://console.vast.ai/api/v0/instances"), \ + patch('vast.http_get', return_value=mock_response): + result = show__instances(args, extra={}) + + assert result is not None + row = result[0] + # strip_strings should have trimmed the status value + assert row['status'] == 'running', f"Expected 'running', got '{row['status']}' -- strip not applied" diff --git a/tests/regression/test_show_machine.py b/tests/regression/test_show_machine.py new file mode 100644 index 00000000..f75d0b56 --- /dev/null +++ b/tests/regression/test_show_machine.py @@ -0,0 +1,58 @@ +"""show__machine() doesn't handle single dict response. + +The bug: api_call returns a dict for single-machine queries. 
The code +iterates `for row in rows` which iterates over dict KEYS (strings like 'id', +'gpu_name') instead of dicts. display_table also expects a list of dicts. + +The fix: Wrap single dict responses in a list. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +from unittest.mock import patch, MagicMock + + +def test_single_dict_response_handled(): + """show__machine wraps a single dict response in a list.""" + from vast import show__machine + + single_machine = { + "id": 42, + "gpu_name": "RTX 4090", + "num_gpus": 1, + } + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, id=42, + ) + + with patch('vast.api_call', return_value=single_machine): + result = show__machine(args) + + # In raw mode, should return a list (wrapped dict) + assert isinstance(result, list), f"Expected list, got {type(result)}" + assert len(result) == 1 + assert result[0]['id'] == 42 + + +def test_list_response_unchanged(): + """show__machine leaves list responses as-is.""" + from vast import show__machine + + machine_list = [{"id": 42, "gpu_name": "RTX 4090"}] + + args = argparse.Namespace( + api_key="test", url="https://console.vast.ai", + retry=3, raw=True, explain=False, quiet=False, + curl=False, id=42, + ) + + with patch('vast.api_call', return_value=machine_list): + result = show__machine(args) + + assert isinstance(result, list) + assert len(result) == 1 diff --git a/tests/regression/test_timeout.py b/tests/regression/test_timeout.py new file mode 100644 index 00000000..330d994c --- /dev/null +++ b/tests/regression/test_timeout.py @@ -0,0 +1,129 @@ +"""Missing timeout on HTTP requests. + +The bug: session.send() has no timeout parameter. Requests can hang +indefinitely if the server never responds or the connection stalls. + +The fix: Add timeout=DEFAULT_TIMEOUT (30s) to http_request() and forward +it to session.send(). 
All wrapper functions (http_get, http_post, http_put, +http_del) accept and forward the timeout parameter. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +from unittest.mock import MagicMock, patch, call +import pytest + + +def _make_args(retry=1): + return argparse.Namespace(retry=retry, curl=False) + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_default_timeout_forwarded(mock_session_cls, mock_sleep): + """http_request passes timeout=30 (DEFAULT_TIMEOUT) to session.send by default.""" + from vast import http_request, DEFAULT_TIMEOUT + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + success_response = MagicMock() + success_response.status_code = 200 + mock_session.send.return_value = success_response + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=1) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + # Verify timeout was passed to session.send + mock_session.send.assert_called_once() + _, kwargs = mock_session.send.call_args + assert kwargs.get('timeout') == DEFAULT_TIMEOUT + assert kwargs.get('timeout') == 30 + + +@patch('vast.time.sleep') +@patch('vast.requests.Session') +def test_custom_timeout_forwarded(mock_session_cls, mock_sleep): + """http_request forwards a custom timeout value to session.send.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + success_response = MagicMock() + success_response.status_code = 200 + mock_session.send.return_value = success_response + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=1) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test', timeout=120) + + assert result.status_code == 200 + _, kwargs = mock_session.send.call_args + 
assert kwargs.get('timeout') == 120 + + +@patch('vast.http_request') +def test_http_get_forwards_timeout(mock_http_request): + """http_get forwards timeout parameter to http_request.""" + from vast import http_get + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_get(args, 'http://example.com/test', timeout=60) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 60 + + +@patch('vast.http_request') +def test_http_post_forwards_timeout(mock_http_request): + """http_post forwards timeout parameter to http_request.""" + from vast import http_post + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_post(args, 'http://example.com/test', timeout=90) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 90 + + +@patch('vast.http_request') +def test_http_put_forwards_timeout(mock_http_request): + """http_put forwards timeout parameter to http_request.""" + from vast import http_put + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_put(args, 'http://example.com/test', timeout=45) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 45 + + +@patch('vast.http_request') +def test_http_del_forwards_timeout(mock_http_request): + """http_del forwards timeout parameter to http_request.""" + from vast import http_del + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_del(args, 'http://example.com/test', timeout=15) + + mock_http_request.assert_called_once() + _, kwargs = mock_http_request.call_args + assert kwargs.get('timeout') == 15 diff --git a/tests/regression/test_timezone_handling.py b/tests/regression/test_timezone_handling.py new file mode 100644 index 00000000..2bddcad0 --- /dev/null +++ 
b/tests/regression/test_timezone_handling.py @@ -0,0 +1,41 @@ +"""Timezone handling uses time.mktime() which interprets as local time. + +The bug: time.mktime() converts a time tuple to epoch using the LOCAL timezone. +For a user in PST (UTC-8), a date "01/15/2025" would produce an epoch value +that's 8 hours later than UTC midnight, giving wrong results. + +The fix: Use calendar.timegm() which always interprets the time tuple as UTC. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import calendar + + +def test_string_to_unix_epoch_utc(): + """string_to_unix_epoch returns UTC timestamps regardless of local timezone.""" + from vast import string_to_unix_epoch + + # 01/15/2025 00:00:00 UTC = 1736899200 + result = string_to_unix_epoch("01/15/2025") + expected = calendar.timegm((2025, 1, 15, 0, 0, 0, 0, 0, 0)) + assert expected == 1736899200, f"Sanity check: expected 1736899200, got {expected}" + assert result == expected, ( + f"string_to_unix_epoch('01/15/2025') returned {result}, expected {expected}. " + f"This likely means time.mktime() is still being used instead of calendar.timegm()." + ) + + +def test_string_to_unix_epoch_returns_float_passthrough(): + """string_to_unix_epoch returns float values as-is.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch("1736899200") == 1736899200.0 + + +def test_string_to_unix_epoch_none(): + """string_to_unix_epoch returns None for None input.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch(None) is None diff --git a/tests/regression/test_unreachable_code.py b/tests/regression/test_unreachable_code.py new file mode 100644 index 00000000..f70e3b93 --- /dev/null +++ b/tests/regression/test_unreachable_code.py @@ -0,0 +1,227 @@ +"""No unreachable code after raise_for_status(). + +The bug: 19 functions had `if (r.status_code == 200):` guards immediately +after `r.raise_for_status()`. 
Since raise_for_status() raises HTTPError for +non-2xx responses, the status_code check was always True and the `else` branch +(printing "failed with error {r.status_code}") was unreachable dead code. + +The fix: Remove the redundant status_code == 200 wrapper, de-indent the +success path, and remove the unreachable else branches. API-level success +checks (rj.get("success")) are preserved since those check the JSON body, +not the HTTP status. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import argparse +import pytest +from unittest import mock +from unittest.mock import MagicMock, patch +from io import StringIO + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +# --------------------------------------------------------------------------- # +# Lint-style tests # +# --------------------------------------------------------------------------- # + +class TestNoUnreachableStatusCheck: + """Ensure no status_code == 200 checks follow raise_for_status().""" + + def test_no_status_check_after_raise_for_status(self): + """Pattern: raise_for_status() followed within 3 lines by status_code == 200.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + # Match raise_for_status() followed (within a few lines) by status_code == 200 + pattern = r'raise_for_status\(\)\s*\n(?:\s*\n)*\s*if\s*\(?r\.status_code\s*==\s*200' + matches = re.findall(pattern, content) + assert len(matches) == 0, ( + f"Found {len(matches)} unreachable status_code == 200 checks after " + f"raise_for_status(). These are unreachable and should be removed." 
+ ) + + def test_no_status_check_after_raise_for_status_response_var(self): + """Same pattern but with 'response' variable name.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + pattern = r'raise_for_status\(\)\s*\n(?:\s*\n)*\s*if\s*\(?response\.status_code\s*==\s*200' + matches = re.findall(pattern, content) + assert len(matches) == 0, ( + f"Found {len(matches)} unreachable status_code == 200 checks (response var) " + f"after raise_for_status()." + ) + + def test_no_unreachable_failed_with_error_after_unconditional_raise(self): + """The 'failed with error' message should not appear after an + *unconditional* raise_for_status() call (same indentation as function body). + Conditional raise_for_status() calls (inside if blocks) are excluded.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + lines = f.readlines() + + # Find unconditional raise_for_status lines (indented exactly 4 spaces, + # i.e., at function-body level, not nested inside an if/else) + rfs_indices = [] + for i, line in enumerate(lines): + stripped = line.rstrip() + if 'raise_for_status()' in stripped and not stripped.lstrip().startswith('#'): + # Check indentation: unconditional means base function indent (4 spaces) + indent = len(line) - len(line.lstrip()) + if indent == 4: + rfs_indices.append(i) + + for rfs_idx in rfs_indices: + # Check within 20 lines after raise_for_status for the dead pattern + for offset in range(1, 20): + check_idx = rfs_idx + offset + if check_idx >= len(lines): + break + line = lines[check_idx] + # If we hit a new function def or decorator, stop scanning + if re.match(r'^def\s+', line) or re.match(r'^@', line): + break + if 'failed with error {r.status_code}' in line: + assert False, ( + f"Line {check_idx + 1}: Found unreachable 'failed with error' " + f"message after unconditional raise_for_status() at line {rfs_idx + 1}" + ) + + +# --------------------------------------------------------------------------- # +# Functional tests # +# 
--------------------------------------------------------------------------- # + +class TestStartInstanceBehavior: + """Verify start_instance works correctly after unreachable code removal.""" + + @patch('vast.http_put') + def test_success_prints_message(self, mock_put, capsys): + """When API returns 200 with success=True, print starting message.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + result = vast.start_instance(12345, args) + + assert result is True + captured = capsys.readouterr() + assert "starting instance" in captured.out + + @patch('vast.http_put') + def test_api_failure_prints_msg(self, mock_put, capsys): + """When API returns 200 with success=False, print the error msg.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": False, "msg": "instance not found"} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + result = vast.start_instance(12345, args) + + assert result is True + captured = capsys.readouterr() + assert "instance not found" in captured.out + + @patch('vast.http_put') + def test_http_error_raises(self, mock_put): + """When API returns 500, raise_for_status raises HTTPError.""" + import vast + import requests + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + "500 Server Error" + ) + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + 
retry=3, raw=False, explain=False + ) + + with pytest.raises(requests.exceptions.HTTPError): + vast.start_instance(12345, args) + + @patch('vast.http_put') + def test_missing_msg_uses_default(self, mock_put, capsys): + """When API returns success=False without msg, use default error.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": False} + mock_put.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + vast.start_instance(12345, args) + + captured = capsys.readouterr() + assert "Unknown error" in captured.out + + +class TestDestroyInstanceBehavior: + """Verify destroy_instance preserves raw mode path after refactor.""" + + @patch('vast.http_del') + def test_raw_mode_returns_json(self, mock_del): + """In raw mode, destroy_instance returns parsed JSON.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} + mock_del.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=True, explain=False + ) + + result = vast.destroy_instance(12345, args) + assert result == {"success": True} + + @patch('vast.http_del') + def test_non_raw_prints_destroying(self, mock_del, capsys): + """In non-raw mode, destroy_instance prints 'destroying instance'.""" + import vast + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"success": True} + mock_del.return_value = mock_response + + args = argparse.Namespace( + api_key="test-key", url="https://console.vast.ai", + retry=3, raw=False, explain=False + ) + + vast.destroy_instance(12345, args) + + captured = 
capsys.readouterr() + assert "destroying instance" in captured.out diff --git a/tests/regression/test_utc_display.py b/tests/regression/test_utc_display.py new file mode 100644 index 00000000..0d31ca32 --- /dev/null +++ b/tests/regression/test_utc_display.py @@ -0,0 +1,97 @@ +"""UTC-labeled timestamps must actually display UTC. + +The bug: datetime.fromtimestamp(ts) without tz= returns LOCAL time, but +the columns are labeled "UTC" or the output is treated as UTC. Users in +non-UTC timezones see wrong times. + +The fix: All fromtimestamp() calls that produce UTC-labeled output now +pass tz=timezone.utc explicitly. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest +from datetime import datetime, timezone + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestUnixToReadableUTC: + """Verify unix_to_readable() produces UTC output.""" + + def test_known_epoch_returns_utc_midnight(self): + """1704067200 is 2024-01-01 00:00:00 UTC. Output must show 00:00:00.""" + from vast import unix_to_readable + result = unix_to_readable(1704067200) + assert "00:00:00" in result, ( + f"unix_to_readable(1704067200) should show 00:00:00 UTC, got: {result}" + ) + + def test_known_epoch_returns_correct_date(self): + """1704067200 is 2024-01-01 UTC. Output must contain Jan-01-2024.""" + from vast import unix_to_readable + result = unix_to_readable(1704067200) + assert "Jan-01-2024" in result, ( + f"unix_to_readable(1704067200) should contain Jan-01-2024, got: {result}" + ) + + def test_midday_epoch_shows_utc_time(self): + """1704110400 is 2024-01-01 12:00:00 UTC. 
Output must show 12:00:00.""" + from vast import unix_to_readable + result = unix_to_readable(1704110400) + assert "12:00:00" in result, ( + f"unix_to_readable(1704110400) should show 12:00:00 UTC, got: {result}" + ) + + +class TestFromtimestampCallsUseUTC: + """Lint-style test: every fromtimestamp() call must use tz=timezone.utc.""" + + def test_all_fromtimestamp_calls_have_tz_utc(self): + """Every fromtimestamp( call in vast.py should include tz=timezone.utc. + + This prevents regressions where a new fromtimestamp() call is added + without the timezone parameter, silently producing local time. + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + lines = f.readlines() + + violations = [] + for i, line in enumerate(lines, 1): + # Skip comments + stripped = line.strip() + if stripped.startswith('#'): + continue + # Find fromtimestamp( calls + if 'fromtimestamp(' in line and 'utcfromtimestamp' not in line: + if 'tz=timezone.utc' not in line: + violations.append(f"Line {i}: {stripped}") + + assert len(violations) == 0, ( + f"Found {len(violations)} fromtimestamp() calls without tz=timezone.utc:\n" + + "\n".join(violations) + ) + + +class TestCacheAgeUsesAwareDatetimes: + """Verify cache age calculation uses timezone-aware datetimes on both sides.""" + + def test_cache_age_uses_tz_aware_now(self): + """datetime.now() in cache age must use tz=timezone.utc.""" + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Find the cache_age line + match = re.search(r'cache_age\s*=\s*(.+)', content) + assert match is not None, "cache_age assignment not found in vast.py" + + cache_age_line = match.group(1) + assert 'datetime.now(tz=timezone.utc)' in cache_age_line, ( + f"cache_age should use datetime.now(tz=timezone.utc), got: {cache_age_line}" + ) + assert 'fromtimestamp(' in cache_age_line and 'tz=timezone.utc' in cache_age_line, ( + f"cache_age should use fromtimestamp with tz=timezone.utc, got: {cache_age_line}" + ) diff --git 
a/tests/regression/test_utcfromtimestamp.py b/tests/regression/test_utcfromtimestamp.py new file mode 100644 index 00000000..4d5dd255 --- /dev/null +++ b/tests/regression/test_utcfromtimestamp.py @@ -0,0 +1,93 @@ +"""No deprecated utcfromtimestamp() calls in vast.py. + +The bug: datetime.utcfromtimestamp() is deprecated since Python 3.12 and +will be removed in a future version. It also returns a naive datetime that +is ambiguous about its timezone. + +The fix: Replace all utcfromtimestamp() calls with +datetime.fromtimestamp(ts, tz=timezone.utc), which returns an aware +datetime and is not deprecated. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import re +import pytest +from datetime import datetime, timezone + + +VAST_PY_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'vast.py') + + +class TestNoUtcfromtimestamp: + """Lint-style tests verifying utcfromtimestamp is not used anywhere.""" + + def test_no_utcfromtimestamp_in_source(self): + """utcfromtimestamp should not appear anywhere in vast.py. + + This deprecated method returns naive datetimes and will be removed + in a future Python version. Use fromtimestamp(ts, tz=timezone.utc). + """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + matches = re.findall(r'utcfromtimestamp', content) + assert len(matches) == 0, ( + f"Found {len(matches)} occurrences of utcfromtimestamp in vast.py. " + "Use datetime.fromtimestamp(ts, tz=timezone.utc) instead." + ) + + def test_no_utcnow_in_source(self): + """utcnow() is also deprecated for the same reason as utcfromtimestamp. + + Prevent regression by ensuring neither deprecated UTC method is used. 
+ """ + with open(VAST_PY_PATH, encoding='utf-8') as f: + content = f.read() + + # Match utcnow() but not in comments + lines = content.split('\n') + violations = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + if stripped.startswith('#'): + continue + if 'utcnow()' in line: + violations.append(f"Line {i}: {stripped}") + + assert len(violations) == 0, ( + f"Found {len(violations)} occurrences of utcnow() in vast.py. " + "Use datetime.now(tz=timezone.utc) instead.\n" + + "\n".join(violations) + ) + + +class TestReplacementProducesAwareDatetime: + """Verify the replacement fromtimestamp(ts, tz=timezone.utc) works correctly.""" + + def test_fromtimestamp_with_tz_returns_aware_datetime(self): + """datetime.fromtimestamp(ts, tz=timezone.utc) must return a tz-aware datetime.""" + ts = 1704067200 # 2024-01-01 00:00:00 UTC + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + assert dt.tzinfo is not None, "Result should be timezone-aware" + assert dt.tzinfo == timezone.utc, "Result timezone should be UTC" + + def test_fromtimestamp_with_tz_correct_values(self): + """Verify the replacement produces correct year/month/day/hour.""" + ts = 1704067200 # 2024-01-01 00:00:00 UTC + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + assert dt.year == 2024 + assert dt.month == 1 + assert dt.day == 1 + assert dt.hour == 0 + assert dt.minute == 0 + assert dt.second == 0 + + def test_fromtimestamp_with_tz_matches_expected_format(self): + """The replacement in schedule_maintenance should format correctly.""" + ts = 1704067200 # 2024-01-01 00:00:00 UTC + dt = datetime.fromtimestamp(ts, tz=timezone.utc) + # The schedule_maintenance function uses str(dt) implicitly in f-string + dt_str = str(dt) + assert "2024-01-01" in dt_str, f"Expected 2024-01-01 in {dt_str}" diff --git a/tests/smoke/__init__.py b/tests/smoke/__init__.py new file mode 100644 index 00000000..157ea08b --- /dev/null +++ b/tests/smoke/__init__.py @@ -0,0 +1 @@ +"""Smoke tests for CLI commands and 
standalone execution.""" diff --git a/tests/smoke/test_cli_commands.py b/tests/smoke/test_cli_commands.py new file mode 100644 index 00000000..dab965bf --- /dev/null +++ b/tests/smoke/test_cli_commands.py @@ -0,0 +1,429 @@ +"""Smoke tests for CLI commands. + +TEST-08: CLI smoke tests - major commands parse args and call correct endpoints. + +These tests verify that CLI commands: +1. Parse their arguments correctly +2. Call the expected API endpoints +3. Handle mocked responses appropriately + +All HTTP is mocked - no network calls are made. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +import json +import pytest +from unittest.mock import MagicMock, patch + + +def _make_base_args(**overrides): + """Create base args namespace with common fields.""" + args = argparse.Namespace( + api_key="test-key", + url="https://console.vast.ai", + retry=3, + raw=False, + explain=False, + quiet=False, + curl=False, + full=False, + no_color=True, + debugging=False, + ) + for key, value in overrides.items(): + setattr(args, key, value) + return args + + +@pytest.fixture(autouse=True) +def setup_vast_args(): + """Set up vast.ARGS to prevent NoneType errors in http_request.""" + import vast + old_args = vast.ARGS + vast.ARGS = _make_base_args() + yield + vast.ARGS = old_args + + +class TestSearchOffers: + """Smoke tests for 'search offers' command.""" + + @patch('vast.http_post') + def test_search_offers_calls_bundles_endpoint(self, mock_http_post): + """search offers calls /api/v0/bundles/ endpoint via POST.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"offers": []} + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.raise_for_status = MagicMock() + mock_http_post.return_value = mock_response + + args = _make_base_args( + query=["gpu_ram>=8"], + type="bid", + raw=True, + no_default=False, + new=False, + 
limit=None, + disable_bundling=False, + storage=5.0, + order="score-", + ) + vast.search__offers(args) + + mock_http_post.assert_called_once() + call_url = mock_http_post.call_args[0][1] + assert "/api/v0/bundles" in call_url + + @patch('vast.http_post') + def test_search_offers_with_gpu_name(self, mock_http_post): + """search offers parses gpu_name filter.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"offers": []} + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.raise_for_status = MagicMock() + mock_http_post.return_value = mock_response + + args = _make_base_args( + query=["gpu_name=RTX_4090"], + type="on-demand", + raw=True, + no_default=False, + new=False, + limit=None, + disable_bundling=False, + storage=5.0, + order="score-", + ) + vast.search__offers(args) + + mock_http_post.assert_called_once() + + +class TestShowInstances: + """Smoke tests for 'show instances' command.""" + + @patch('vast.http_get') + def test_show_instances_calls_instances_endpoint(self, mock_http_get): + """show instances calls /api/v0/instances/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"instances": []} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_get.return_value = mock_response + + args = _make_base_args(raw=True) + vast.show__instances(args) + + mock_http_get.assert_called_once() + call_url = mock_http_get.call_args[0][1] + assert "/api/v0/instances" in call_url + + @patch('vast.http_get') + def test_show_instances_returns_list(self, mock_http_get): + """show instances handles list response correctly.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = { + "instances": [ + {"id": 123, "status": "running", "start_date": 1700000000, "extra_env": []}, + {"id": 456, "status": "stopped", "start_date": 1700000000, "extra_env": []} + ] + } + 
mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_get.return_value = mock_response + + args = _make_base_args(raw=True) + result = vast.show__instances(args) + + # In raw mode, should return the data + assert result is not None or mock_http_get.called + + +class TestShowMachines: + """Smoke tests for 'show machines' command.""" + + @patch('vast.api_call') + def test_show_machines_calls_machines_endpoint(self, mock_api_call): + """show machines calls /api/v0/machines/ endpoint.""" + import vast + + mock_api_call.return_value = {"machines": []} + + args = _make_base_args(raw=True, quiet=True) + vast.show__machines(args) + + mock_api_call.assert_called_once() + call_args = mock_api_call.call_args + assert call_args[0][1] == "GET" + assert "/machines" in call_args[0][2] + + +class TestShowUser: + """Smoke tests for 'show user' command.""" + + @patch('vast.api_call') + def test_show_user_calls_users_endpoint(self, mock_api_call): + """show user calls /api/v0/users/current endpoint.""" + import vast + + mock_api_call.return_value = {"id": 12345, "username": "testuser", "api_key": "secret"} + + args = _make_base_args(raw=True, quiet=True) + vast.show__user(args) + + mock_api_call.assert_called_once() + call_args = mock_api_call.call_args + assert call_args[0][1] == "GET" + assert "/users/current" in call_args[0][2] + + +class TestCreateInstance: + """Smoke tests for 'create instance' command.""" + + @patch('vast.http_put') + def test_create_instance_calls_asks_endpoint(self, mock_http_put): + """create instance calls /api/v0/asks/{id}/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"success": True, "new_contract": 789} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_put.return_value = mock_response + + args = _make_base_args( + id=12345, + bid_price=None, + disk=20.0, + image="pytorch/pytorch:latest", + raw=True, + onstart=None, + 
onstart_cmd=None, + entrypoint=None, + env=None, + args=None, + label=None, + extra=None, + jupyter=False, + jupyter_dir=None, + jupyter_lab=False, + lang_utf8=False, + python_utf8=False, + ssh=False, + direct=False, + cancel_unavail=False, + force=False, + login=None, + template_hash=None, + user=None, + create_volume=None, + link_volume=None, + ) + vast.create__instance(args) + + mock_http_put.assert_called_once() + call_url = mock_http_put.call_args[0][1] + assert "/api/v0/asks" in call_url + + +class TestDestroyInstance: + """Smoke tests for 'destroy instance' command.""" + + @patch('vast.http_del') + def test_destroy_instance_calls_instances_endpoint(self, mock_http_del): + """destroy instance calls DELETE /api/v0/instances/{id}/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = {"success": True} + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_del.return_value = mock_response + + args = _make_base_args(id=12345, raw=True) + vast.destroy__instance(args) + + mock_http_del.assert_called_once() + call_url = mock_http_del.call_args[0][1] + assert "/api/v0/instances" in call_url + assert "12345" in call_url + + +class TestLogsCommand: + """Smoke tests for 'logs' command.""" + + @patch('vast.http_get') + @patch('vast.http_put') + def test_logs_calls_instances_endpoint(self, mock_http_put, mock_http_get): + """logs command calls appropriate endpoint.""" + import vast + + # Mock the logs request + mock_put_response = MagicMock() + mock_put_response.json.return_value = {"result_url": "https://example.com/logs"} + mock_put_response.status_code = 200 + mock_put_response.raise_for_status = MagicMock() + mock_http_put.return_value = mock_put_response + + # Mock the log fetch + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.text = "Log output here" + mock_http_get.return_value = mock_get_response + + args = _make_base_args( + INSTANCE_ID=123, + 
raw=True, + tail=None, + filter=None, + daemon_logs=False, + ) + + # logs may print and not return in non-raw mode + vast.logs(args) + + # Should have made HTTP calls + assert mock_http_put.called + + +class TestSetApiKey: + """Smoke tests for 'set api-key' command.""" + + @patch('builtins.open', create=True) + @patch('os.path.exists') + def test_set_api_key_writes_file(self, mock_exists, mock_open): + """set api-key writes key to config file.""" + import vast + + mock_exists.return_value = False + mock_file = MagicMock() + mock_open.return_value.__enter__ = MagicMock(return_value=mock_file) + mock_open.return_value.__exit__ = MagicMock(return_value=False) + + args = _make_base_args(new_api_key="sk-test-key-12345") + vast.set__api_key(args) + + mock_open.assert_called() + + +class TestShowApiKeys: + """Smoke tests for 'show api-keys' command.""" + + @patch('vast.http_get') + def test_show_api_keys_calls_auth_endpoint(self, mock_http_get): + """show api-keys calls /api/v0/auth/apikeys/ endpoint.""" + import vast + + mock_response = MagicMock() + mock_response.json.return_value = [{"id": 1, "name": "test-key"}] + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_http_get.return_value = mock_response + + args = _make_base_args(raw=True) + result = vast.show__api_keys(args) + + mock_http_get.assert_called_once() + call_url = mock_http_get.call_args[0][1] + assert "/api/v0/auth/apikeys" in call_url + + +class TestParserStructure: + """Tests verifying the argparse parser structure.""" + + def test_parser_has_subcommands(self): + """Parser has expected subcommand structure.""" + import vast + + # Parser should exist + assert hasattr(vast, 'parser') + + # Should be able to parse help + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['--help']) + assert exc_info.value.code == 0 + + def test_search_offers_subcommand_exists(self): + """search offers subcommand is registered.""" + import vast + + # Should not raise - 
parse_args with --help raises SystemExit(0) + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['search', 'offers', '--help']) + assert exc_info.value.code == 0 + + def test_show_instances_subcommand_exists(self): + """show instances subcommand is registered.""" + import vast + + # Parse minimal args (should work) + args = vast.parser.parse_args(['show', 'instances']) + assert args is not None + + def test_create_instance_subcommand_exists(self): + """create instance subcommand is registered.""" + import vast + + # Should parse with required args + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['create', 'instance', '--help']) + assert exc_info.value.code == 0 + + def test_destroy_instance_subcommand_exists(self): + """destroy instance subcommand is registered.""" + import vast + + # Should parse with required args + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['destroy', 'instance', '--help']) + assert exc_info.value.code == 0 + + def test_show_user_subcommand_exists(self): + """show user subcommand is registered.""" + import vast + + args = vast.parser.parse_args(['show', 'user']) + assert args is not None + + def test_show_machines_subcommand_exists(self): + """show machines subcommand is registered.""" + import vast + + args = vast.parser.parse_args(['show', 'machines']) + assert args is not None + + def test_set_api_key_subcommand_exists(self): + """set api-key subcommand is registered.""" + import vast + + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['set', 'api-key', '--help']) + assert exc_info.value.code == 0 + + def test_logs_subcommand_exists(self): + """logs subcommand is registered.""" + import vast + + with pytest.raises(SystemExit) as exc_info: + vast.parser.parse_args(['logs', '--help']) + assert exc_info.value.code == 0 + + def test_show_api_keys_subcommand_exists(self): + """show api-keys subcommand is registered.""" + import vast + + args = 
vast.parser.parse_args(['show', 'api-keys']) + assert args is not None diff --git a/tests/smoke/test_standalone.py b/tests/smoke/test_standalone.py new file mode 100644 index 00000000..918689cd --- /dev/null +++ b/tests/smoke/test_standalone.py @@ -0,0 +1,275 @@ +"""Smoke tests for standalone vast.py execution. + +TEST-09: Standalone vast.py smoke test - python vast.py --help works without pip dependencies. + +This test verifies that vast.py can be executed as a standalone script +with only the minimal dependencies (requests, python-dateutil) available. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import subprocess +import pytest + + +# Get the path to vast.py relative to this test file +VAST_CLI_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +VAST_PY_PATH = os.path.join(VAST_CLI_DIR, 'vast.py') + + +class TestStandaloneHelp: + """Tests for standalone vast.py --help execution.""" + + def test_vast_help_exits_zero(self): + """vast.py --help exits with code 0.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + def test_vast_help_contains_usage(self): + """vast.py --help output contains usage information.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert 'usage' in result.stdout.lower() or 'vast' in result.stdout.lower(), \ + f"Help output missing expected content: {result.stdout[:500]}" + + def test_vast_help_contains_commands(self): + """vast.py --help output lists available commands.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # Should mention some key commands + output = 
result.stdout.lower() + assert 'search' in output or 'show' in output or 'create' in output, \ + f"Help output missing command listings: {result.stdout[:500]}" + + +class TestSubcommandHelp: + """Tests for subcommand help output.""" + + def test_search_offers_help(self): + """vast.py search offers --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'search', 'offers', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + assert 'search' in result.stdout.lower() or 'offers' in result.stdout.lower() + + def test_show_instances_help(self): + """vast.py show instances --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'show', 'instances', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + def test_create_instance_help(self): + """vast.py create instance --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'create', 'instance', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + def test_destroy_instance_help(self): + """vast.py destroy instance --help works.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'destroy', 'instance', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Exit code was {result.returncode}, stderr: {result.stderr}" + + +class TestVersionFlag: + """Tests for version flag if implemented.""" + + def test_vast_version_or_help(self): + """vast.py responds to --version or --help without error.""" + # Try --version first + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--version'], + 
capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # If --version isn't implemented, that's okay - just verify it doesn't crash + # (might return non-zero for unrecognized flag, but shouldn't hang or throw) + assert result.returncode in [0, 1, 2], \ + f"Unexpected exit code {result.returncode}, stderr: {result.stderr}" + + +class TestInvalidCommand: + """Tests for invalid command handling.""" + + def test_invalid_command_exits_nonzero(self): + """vast.py with invalid command exits with non-zero code.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'not_a_real_command'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # Should exit non-zero for invalid command + assert result.returncode != 0, "Invalid command should exit non-zero" + + def test_invalid_command_prints_error(self): + """vast.py with invalid command prints error message.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'not_a_real_command'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + # Should have some error output + combined_output = result.stdout + result.stderr + assert len(combined_output) > 0, "Should produce some output for invalid command" + + +class TestImportOnly: + """Tests that verify vast.py can be imported without side effects.""" + + def test_vast_importable(self): + """vast.py is importable as a module.""" + # This runs in the test process, verifying import works + import vast + + assert hasattr(vast, 'parser') + assert hasattr(vast, 'main') + + def test_vast_main_exists(self): + """vast.main() function exists.""" + import vast + + assert callable(vast.main) + + def test_vast_has_core_functions(self): + """vast module has core CLI functions.""" + import vast + + # Check for key command functions + assert hasattr(vast, 'search__offers') + assert hasattr(vast, 'show__instances') + assert hasattr(vast, 'create__instance') + assert hasattr(vast, 'destroy__instance') + + def 
test_vast_has_http_helpers(self): + """vast module has HTTP helper functions.""" + import vast + + assert hasattr(vast, 'http_get') + assert hasattr(vast, 'http_post') + assert hasattr(vast, 'http_put') + assert hasattr(vast, 'http_del') + + +class TestMinimalDependencies: + """Tests verifying vast.py works with minimal dependencies.""" + + def test_import_only_requires_requests(self): + """vast.py import only requires requests (and python-dateutil).""" + # Run a subprocess that only has requests available + # This is tested implicitly by the fact that we can import vast + # without extra dependencies installed + result = subprocess.run( + [sys.executable, '-c', 'import vast; print("OK")'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0, f"Import failed: {result.stderr}" + assert "OK" in result.stdout + + def test_help_runs_without_optional_deps(self): + """vast.py --help works even if optional dependencies are missing.""" + # argcomplete and curlify are optional + # This test verifies help still works + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + + assert result.returncode == 0 + + +class TestExitCodes: + """Tests for proper exit codes.""" + + def test_help_exits_zero(self): + """--help exits with code 0.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + assert result.returncode == 0 + + def test_subcommand_help_exits_zero(self): + """Subcommand --help exits with code 0.""" + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'show', 'instances', '--help'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + assert result.returncode == 0 + + def test_missing_required_arg_exits_nonzero(self): + """Missing required argument exits with non-zero code.""" + # create instance 
requires an ID + result = subprocess.run( + [sys.executable, VAST_PY_PATH, 'create', 'instance'], + capture_output=True, + text=True, + cwd=VAST_CLI_DIR, + timeout=30, + ) + assert result.returncode != 0 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..8ff8dac6 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for vast.py helper functions.""" diff --git a/tests/unit/test_http_helpers.py b/tests/unit/test_http_helpers.py new file mode 100644 index 00000000..3441f628 --- /dev/null +++ b/tests/unit/test_http_helpers.py @@ -0,0 +1,267 @@ +"""Unit tests for HTTP helper functions (api_call, http_*, retry logic). + +TEST-02: Unit tests for HTTP helper functions with mocked responses. + +These tests verify the helper functions in isolation, complementing the +regression tests that focus on specific bug fixes. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import argparse +import json +import pytest +from unittest.mock import MagicMock, patch, call +from requests.exceptions import ConnectionError, Timeout, HTTPError + + +def _make_args(retry=3, raw=False, explain=False, curl=False): + """Create minimal args namespace for testing.""" + return argparse.Namespace( + api_key="test-key", + url="https://console.vast.ai", + retry=retry, + raw=raw, + explain=explain, + quiet=False, + curl=curl, + ) + + +class TestApiCall: + """Tests for the api_call() helper function.""" + + @patch('vast.http_get') + def test_api_call_get_request(self, mock_http_get): + """api_call makes GET request and returns parsed JSON.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.json.return_value = {"offers": [{"id": 1}]} + mock_http_get.return_value = mock_response + + args = _make_args() + # api_call signature: api_call(args, method, path, *, json_body=None, query_args=None) + result = api_call(args, "GET", "/api/v0/bundles") + + assert result == 
{"offers": [{"id": 1}]} + mock_http_get.assert_called_once() + + @patch('vast.http_post') + def test_api_call_post_request(self, mock_http_post): + """api_call makes POST request with JSON body.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.json.return_value = {"success": True} + mock_http_post.return_value = mock_response + + args = _make_args() + # api_call signature: api_call(args, method, path, *, json_body=None, query_args=None) + result = api_call(args, "POST", "/api/v0/instances/123/", json_body={"action": "start"}) + + assert result == {"success": True} + mock_http_post.assert_called_once() + + @patch('vast.http_get') + def test_api_call_handles_http_error(self, mock_http_get): + """api_call propagates HTTPError from response.""" + from vast import api_call + + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = HTTPError("404 Not Found") + mock_http_get.return_value = mock_response + + args = _make_args() + with pytest.raises(HTTPError): + # api_call signature: api_call(args, method, path, *, json_body=None, query_args=None) + api_call(args, "GET", "/api/v0/nonexistent") + + +class TestHttpHelpers: + """Tests for http_get, http_post, http_put, http_del functions.""" + + @patch('vast.http_request') + def test_http_get_constructs_correct_request(self, mock_http_request): + """http_get passes correct method and URL to http_request.""" + from vast import http_get + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_get(args, "https://example.com/api/test") + + mock_http_request.assert_called_once() + call_args = mock_http_request.call_args + assert call_args[0][0] == "GET" # method + assert call_args[0][2] == "https://example.com/api/test" # url + + @patch('vast.http_request') + def test_http_post_sends_json_body(self, mock_http_request): + """http_post includes JSON body in request.""" + from vast import http_post + + mock_http_request.return_value = 
MagicMock(status_code=200) + args = _make_args() + body = {"key": "value"} + + # http_post signature: http_post(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT) + http_post(args, "https://example.com/api/test", json=body) + + # http_request is called as: http_request('POST', args, req_url, headers, json, timeout=timeout) + # json is passed as positional arg at index 4 + call_args = mock_http_request.call_args[0] + assert call_args[4] == body # json is 5th positional arg (index 4) + + @patch('vast.http_request') + def test_http_put_sends_json_body(self, mock_http_request): + """http_put includes JSON body in request.""" + from vast import http_put + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + body = {"update": "data"} + + # http_put signature: http_put(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT) + http_put(args, "https://example.com/api/test", json=body) + + # http_request is called as: http_request('PUT', args, req_url, headers, json, timeout=timeout) + # json is passed as positional arg at index 4 + call_args = mock_http_request.call_args[0] + assert call_args[4] == body # json is 5th positional arg (index 4) + + @patch('vast.http_request') + def test_http_del_makes_delete_request(self, mock_http_request): + """http_del uses DELETE method.""" + from vast import http_del + + mock_http_request.return_value = MagicMock(status_code=200) + args = _make_args() + + http_del(args, "https://example.com/api/resource/123") + + call_args = mock_http_request.call_args + assert call_args[0][0] == "DELETE" + + +class TestHttpRequestRetry: + """Tests for retry logic in http_request.""" + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_retry_on_connection_error(self, mock_session_cls, mock_sleep): + """http_request retries on ConnectionError.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + # First call fails, 
second succeeds + success_response = MagicMock(status_code=200) + mock_session.send.side_effect = [ + ConnectionError("Connection refused"), + success_response + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_retry_on_timeout(self, mock_session_cls, mock_sleep): + """http_request retries on Timeout.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + success_response = MagicMock(status_code=200) + mock_session.send.side_effect = [ + Timeout("Request timed out"), + success_response + ] + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def test_retry_on_503_status(self, mock_session_cls, mock_sleep): + """http_request retries on 503 Service Unavailable.""" + from vast import http_request, RETRYABLE_STATUS_CODES + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + error_response = MagicMock(status_code=503) + success_response = MagicMock(status_code=200) + mock_session.send.side_effect = [error_response, success_response] + mock_session.prepare_request.return_value = MagicMock() + + assert 503 in RETRYABLE_STATUS_CODES + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 200 + assert mock_session.send.call_count == 2 + + @patch('vast.time.sleep') + @patch('vast.requests.Session') + def 
test_no_retry_on_400_status(self, mock_session_cls, mock_sleep): + """http_request does not retry on 400 Bad Request.""" + from vast import http_request + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + error_response = MagicMock(status_code=400) + mock_session.send.return_value = error_response + mock_session.prepare_request.return_value = MagicMock() + + args = _make_args(retry=3) + with patch('vast.ARGS', args): + result = http_request('GET', args, 'http://example.com/test') + + assert result.status_code == 400 + assert mock_session.send.call_count == 1 # No retry + + +class TestTimeoutConstants: + """Tests for timeout constant values.""" + + def test_default_timeout_defined(self): + """DEFAULT_TIMEOUT constant is defined and reasonable.""" + from vast import DEFAULT_TIMEOUT + + assert DEFAULT_TIMEOUT == 30 + assert isinstance(DEFAULT_TIMEOUT, (int, float)) + + def test_long_timeout_defined(self): + """LONG_TIMEOUT constant is defined for file operations.""" + from vast import LONG_TIMEOUT + + assert LONG_TIMEOUT == 120 + assert LONG_TIMEOUT > 30 # Should be longer than default + + def test_retryable_status_codes_defined(self): + """RETRYABLE_STATUS_CODES contains expected HTTP statuses.""" + from vast import RETRYABLE_STATUS_CODES + + assert 429 in RETRYABLE_STATUS_CODES # Too Many Requests + assert 502 in RETRYABLE_STATUS_CODES # Bad Gateway + assert 503 in RETRYABLE_STATUS_CODES # Service Unavailable + assert 504 in RETRYABLE_STATUS_CODES # Gateway Timeout + assert 500 not in RETRYABLE_STATUS_CODES # 500 not retried (may have side effects) diff --git a/tests/unit/test_query_parser.py b/tests/unit/test_query_parser.py new file mode 100644 index 00000000..2fe89948 --- /dev/null +++ b/tests/unit/test_query_parser.py @@ -0,0 +1,160 @@ +"""Unit tests for query parsing functions (parse_query, field aliases). + +TEST-04: Unit tests for query parsing. 
+ +These tests verify the query parser handles field names, operators, +aliases, and various query formats correctly. +""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import pytest + + +class TestParseQuery: + """Tests for parse_query() function.""" + + def test_simple_equality(self): + """Simple equality operator parses correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram=8", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert result['gpu_ram']['eq'] == '8' + + def test_greater_than_or_equal(self): + """Greater-than-or-equal operator parses correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram>=16", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert 'gte' in result['gpu_ram'] + assert result['gpu_ram']['gte'] == '16' + + def test_less_than_or_equal(self): + """Less-than-or-equal operator parses correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram<=32", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert 'lte' in result['gpu_ram'] + assert result['gpu_ram']['lte'] == '32' + + def test_multiple_conditions(self): + """Multiple space-separated conditions parse correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("gpu_ram>=8 num_gpus>=2", {}, offers_fields, offers_alias) + + assert 'gpu_ram' in result + assert 'num_gpus' in result + + +class TestFieldAliases: + """Tests for field alias handling in parse_query.""" + + def test_cuda_vers_alias(self): + """cuda_vers is aliased to cuda_max_good.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers>=12.0", {}, offers_fields, offers_alias) + + # After alias resolution, cuda_vers should become cuda_max_good + assert 'cuda_max_good' in result + assert 
'cuda_vers' not in result + assert 'gte' in result['cuda_max_good'] + + def test_alias_resolution_preserves_value(self): + """Alias resolution preserves the comparison value.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("cuda_vers>=11.8", {}, offers_fields, offers_alias) + + assert result['cuda_max_good']['gte'] == '11.8' + + +class TestOfferFieldsDefinition: + """Tests verifying offers_fields contains expected fields.""" + + def test_gpu_ram_field_exists(self): + """gpu_ram is a valid offer field.""" + from vast import offers_fields + + assert 'gpu_ram' in offers_fields + + def test_num_gpus_field_exists(self): + """num_gpus is a valid offer field.""" + from vast import offers_fields + + assert 'num_gpus' in offers_fields + + def test_cuda_max_good_field_exists(self): + """cuda_max_good is a valid offer field.""" + from vast import offers_fields + + assert 'cuda_max_good' in offers_fields + + def test_dph_total_field_exists(self): + """dph_total (price) is a valid offer field.""" + from vast import offers_fields + + assert 'dph_total' in offers_fields + + +class TestQueryEdgeCases: + """Tests for edge cases in query parsing.""" + + def test_empty_query(self): + """Empty query returns empty result.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("", {}, offers_fields, offers_alias) + + assert result == {} or result is not None + + def test_whitespace_handling(self): + """Whitespace around operators is handled.""" + from vast import parse_query, offers_fields, offers_alias + + # Query with extra whitespace should still parse + result = parse_query("gpu_ram >= 8", {}, offers_fields, offers_alias) + + # Should have gpu_ram field with gte constraint + assert 'gpu_ram' in result + + def test_decimal_values(self): + """Decimal values in queries parse correctly.""" + from vast import parse_query, offers_fields, offers_alias + + result = parse_query("dph_total<=0.5", {}, offers_fields, 
offers_alias) + + assert 'dph_total' in result + assert result['dph_total']['lte'] == '0.5' + + +class TestOfferAliasDefinition: + """Tests for offer field alias definitions.""" + + def test_offers_alias_is_dict(self): + """offers_alias is a dictionary.""" + from vast import offers_alias + + assert isinstance(offers_alias, dict) + + def test_cuda_vers_alias_defined(self): + """cuda_vers -> cuda_max_good alias is defined.""" + from vast import offers_alias + + assert 'cuda_vers' in offers_alias + assert offers_alias['cuda_vers'] == 'cuda_max_good' + + def test_dph_alias_defined(self): + """dph -> dph_total alias is defined.""" + from vast import offers_alias + + assert 'dph' in offers_alias + assert offers_alias['dph'] == 'dph_total' diff --git a/tests/unit/test_timezone.py b/tests/unit/test_timezone.py new file mode 100644 index 00000000..9ea33596 --- /dev/null +++ b/tests/unit/test_timezone.py @@ -0,0 +1,125 @@ +"""Unit tests for timezone handling functions. + +TEST-03: Unit tests for timezone handling functions. + +These tests verify that all timezone conversions produce correct UTC results +regardless of the local system timezone. 
+""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +import calendar +from datetime import datetime, timezone +import pytest + + +class TestStringToUnixEpoch: + """Tests for string_to_unix_epoch() function.""" + + def test_date_string_to_utc_epoch(self): + """Date string converts to correct UTC epoch timestamp.""" + from vast import string_to_unix_epoch + + # 01/15/2025 00:00:00 UTC = 1736899200 + result = string_to_unix_epoch("01/15/2025") + expected = calendar.timegm((2025, 1, 15, 0, 0, 0, 0, 0, 0)) + + assert expected == 1736899200, f"Sanity check failed: expected 1736899200, got {expected}" + assert result == expected + + def test_numeric_string_passthrough(self): + """Numeric string is converted to float and returned.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch("1736899200") == 1736899200.0 + assert string_to_unix_epoch("1736899200.5") == 1736899200.5 + + def test_none_returns_none(self): + """None input returns None.""" + from vast import string_to_unix_epoch + + assert string_to_unix_epoch(None) is None + + def test_empty_string_raises_value_error(self): + """Empty string raises ValueError (not numeric, can't parse as date).""" + from vast import string_to_unix_epoch + + # Empty string is not a valid float and not a valid date format + with pytest.raises(ValueError): + string_to_unix_epoch("") + + def test_various_date_formats(self): + """Various date formats parse correctly.""" + from vast import string_to_unix_epoch + + # Test MM/DD/YYYY format + result1 = string_to_unix_epoch("12/31/2024") + expected1 = calendar.timegm((2024, 12, 31, 0, 0, 0, 0, 0, 0)) + assert result1 == expected1 + + # Test boundary dates + result2 = string_to_unix_epoch("01/01/2025") + expected2 = calendar.timegm((2025, 1, 1, 0, 0, 0, 0, 0, 0)) + assert result2 == expected2 + + +class TestFromTimestampUtc: + """Tests for UTC-aware datetime.fromtimestamp usage.""" + + def test_fromtimestamp_with_utc(self): 
+ """datetime.fromtimestamp with UTC timezone gives correct result.""" + # This tests the pattern used after the fix fix + epoch = 1736899200 # 01/15/2025 00:00:00 UTC + + # The correct pattern (after fix) + dt_utc = datetime.fromtimestamp(epoch, tz=timezone.utc) + + assert dt_utc.year == 2025 + assert dt_utc.month == 1 + assert dt_utc.day == 15 + assert dt_utc.hour == 0 + assert dt_utc.minute == 0 + assert dt_utc.second == 0 + assert dt_utc.tzinfo == timezone.utc + + def test_calendar_timegm_inverse(self): + """calendar.timegm is the inverse of datetime.fromtimestamp(tz=utc).""" + original_epoch = 1736899200 + + # Convert to datetime + dt = datetime.fromtimestamp(original_epoch, tz=timezone.utc) + + # Convert back to epoch + timetuple = dt.timetuple() + recovered_epoch = calendar.timegm(timetuple) + + assert recovered_epoch == original_epoch + + +class TestTimezoneConsistency: + """Tests verifying timezone handling is consistent across functions.""" + + def test_known_epoch_value(self): + """Test against a known epoch value that's verifiable.""" + # Unix epoch 0 is January 1, 1970, 00:00:00 UTC + epoch_zero = 0 + + dt = datetime.fromtimestamp(epoch_zero, tz=timezone.utc) + + assert dt.year == 1970 + assert dt.month == 1 + assert dt.day == 1 + assert dt.hour == 0 + + def test_y2k_epoch(self): + """Test Y2K timestamp (known reference point).""" + # January 1, 2000, 00:00:00 UTC = 946684800 + y2k_epoch = 946684800 + + dt = datetime.fromtimestamp(y2k_epoch, tz=timezone.utc) + + assert dt.year == 2000 + assert dt.month == 1 + assert dt.day == 1 + assert dt.hour == 0 diff --git a/vast.py b/vast.py index 4d737770..ded88cc7 100755 --- a/vast.py +++ b/vast.py @@ -9,7 +9,8 @@ import argparse import os import time -from typing import Dict, List, Tuple, Optional +import calendar +from typing import Any, Dict, List, Tuple, Optional from datetime import date, datetime, timedelta, timezone import hashlib import math @@ -21,6 +22,7 @@ from time import sleep from subprocess import 
PIPE import urllib3 +import ssl import atexit from contextlib import redirect_stdout, redirect_stderr from io import StringIO @@ -43,7 +45,7 @@ try: import argcomplete TABCOMPLETE = True -except: +except ImportError: # No tab-completion for you pass @@ -57,15 +59,7 @@ except ImportError: from urllib.parse import quote_plus # Python 3+ -try: - JSONDecodeError = json.JSONDecodeError -except AttributeError: - JSONDecodeError = ValueError - -try: - input = raw_input -except NameError: - pass +JSONDecodeError = json.JSONDecodeError #server_url_default = "https://vast.ai" @@ -80,6 +74,10 @@ format="%(levelname)s - %(message)s" ) +DEFAULT_TIMEOUT = 30 # seconds -- normal API calls +LONG_TIMEOUT = 120 # seconds -- file operations, large queries +RETRYABLE_STATUS_CODES = {429, 502, 503, 504} + def parse_version(version: str) -> tuple[int, ...]: parts = version.split(".") @@ -117,14 +115,15 @@ def is_pip_package(): except Exception: return False -def get_update_command(stable_version: str) -> str: +def get_update_command(stable_version: str) -> list: if is_pip_package(): + cmd = [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir"] if "test.pypi.org" in PYPI_BASE_PATH: - return f"{sys.executable} -m pip install --force-reinstall --no-cache-dir -i {PYPI_BASE_PATH} vastai=={stable_version}" - else: - return f"{sys.executable} -m pip install --force-reinstall --no-cache-dir vastai=={stable_version}" + cmd.extend(["-i", PYPI_BASE_PATH]) + cmd.append(f"vastai=={stable_version}") + return cmd else: - return f"git fetch --all --tags --prune && git checkout tags/v{stable_version}" + return ["git", "fetch", "--all", "--tags", "--prune"] def get_local_version(): @@ -135,7 +134,7 @@ def get_local_version(): def get_project_data(project_name: str) -> dict[str, dict[str, str]]: url = PYPI_BASE_PATH + f"/pypi/{project_name}/json" - response = requests.get(url, headers={"Accept": "application/json"}) + response = requests.get(url, headers={"Accept": 
"application/json"}, timeout=10) # this will raise for HTTP status 4xx and 5xx response.raise_for_status() @@ -184,12 +183,20 @@ def check_for_update(): print("Updating...") _ = subprocess.run( update_command, - shell=True, check=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) + if not is_pip_package(): + # git case: need a second command to checkout the tag + _ = subprocess.run( + ["git", "checkout", f"tags/v{pypi_version}"], + check=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) print("Update completed successfully!\nAttempt to run your command again!") sys.exit(0) @@ -197,12 +204,6 @@ def check_for_update(): APP_NAME = "vastai" VERSION = get_local_version() -# define emoji support and fallbacks -_HAS_EMOJI = sys.stdout.encoding and 'utf' in sys.stdout.encoding.lower() -SUCCESS = "✅" if _HAS_EMOJI else "[OK]" -WARN = "⚠️" if _HAS_EMOJI else "[!]" -FAIL = "❌" if _HAS_EMOJI else "[X]" -INFO = "ℹ️" if _HAS_EMOJI else "[i]" try: # Although xdg-base-dirs is the newer name, there's @@ -218,7 +219,7 @@ def check_for_update(): 'temp': xdg.xdg_cache_home() } -except: +except (ImportError, KeyError, OSError): # Reasonable defaults. 
DIRS = { 'config': os.path.join(os.getenv('HOME'), '.config'), @@ -237,6 +238,13 @@ def check_for_update(): APIKEY_FILE_HOME = os.path.expanduser("~/.vast_api_key") # Legacy TFAKEY_FILE = os.path.join(DIRS['config'], "vast_tfa_key") +# Emoji support with fallbacks for terminals that don't support Unicode +_HAS_EMOJI = sys.stdout.encoding and 'utf' in sys.stdout.encoding.lower() +SUCCESS = "\u2705" if _HAS_EMOJI else "[OK]" +WARN = "\u26a0\ufe0f" if _HAS_EMOJI else "[!]" +FAIL = "\u274c" if _HAS_EMOJI else "[X]" +INFO = "\u2139\ufe0f" if _HAS_EMOJI else "[i]" + if not os.path.exists(APIKEY_FILE) and os.path.exists(APIKEY_FILE_HOME): #print(f'copying key from {APIKEY_FILE_HOME} -> {APIKEY_FILE}') shutil.copyfile(APIKEY_FILE_HOME, APIKEY_FILE) @@ -287,11 +295,11 @@ def string_to_unix_epoch(date_string): except ValueError: # If not, parse it as a date string date_object = datetime.strptime(date_string, "%m/%d/%Y") - return time.mktime(date_object.timetuple()) + return calendar.timegm(date_object.timetuple()) def unix_to_readable(ts): # ts: integer or float, Unix timestamp - return datetime.fromtimestamp(ts).strftime('%H:%M:%S|%h-%d-%Y') + return datetime.fromtimestamp(ts, tz=timezone.utc).strftime('%H:%M:%S|%h-%d-%Y') def fix_date_fields(query: Dict[str, Dict], date_fields: List[str]): """Takes in a query and date fields to correct and returns query with appropriate epoch dates""" @@ -330,18 +338,13 @@ def __nonzero__(self): def append(self, x): self.l.append(x) -def http_request(verb, args, req_url, headers: dict[str, str] | None = None, json_data = None): +def http_request(verb, args, req_url, headers: dict[str, str] | None = None, json = None, timeout=DEFAULT_TIMEOUT): t = 0.15 + r = None for i in range(0, args.retry): - req = requests.Request(method=verb, url=req_url, headers=headers, json=json_data) + req = requests.Request(method=verb, url=req_url, headers=headers, json=json) session = requests.Session() prep = session.prepare_request(req) - if args.explain: - 
print(f"\n{INFO} Prepared Request:") - print(f"{prep.method} {prep.url}") - print(f"Headers: {json.dumps(headers, indent=1)}") - print(f"Body: {json.dumps(json_data, indent=1)}" + "\n" + "_"*100 + "\n") - if ARGS.curl: as_curl = curlify.to_curl(prep) simple = re.sub(r" -H '[^']*'", '', as_curl) @@ -352,26 +355,42 @@ def http_request(verb, args, req_url, headers: dict[str, str] | None = None, jso print("\n" + ' \\\n '.join(parts).strip() + "\n") sys.exit(0) else: - r = session.send(prep) + try: + r = session.send(prep, timeout=timeout) + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: + if i < args.retry - 1: + time.sleep(t) + t *= 1.5 + continue + raise + except requests.exceptions.RequestException as e: + # Non-retryable request errors (e.g., InvalidURL) + raise - if (r.status_code == 429): + if r.status_code in RETRYABLE_STATUS_CODES: time.sleep(t) t *= 1.5 else: break return r -def http_get(args, req_url, headers = None, json = None): - return http_request('GET', args, req_url, headers, json) +def http_get(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + return http_request('GET', args, req_url, headers, json, timeout=timeout) -def http_put(args, req_url, headers = None, json = {}): - return http_request('PUT', args, req_url, headers, json) +def http_put(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + if json is None: + json = {} + return http_request('PUT', args, req_url, headers, json, timeout=timeout) -def http_post(args, req_url, headers = None, json={}): - return http_request('POST', args, req_url, headers, json) +def http_post(args, req_url, headers=None, json=None, timeout=DEFAULT_TIMEOUT): + if json is None: + json = {} + return http_request('POST', args, req_url, headers, json, timeout=timeout) -def http_del(args, req_url, headers = None, json={}): - return http_request('DELETE', args, req_url, headers, json) +def http_del(args, req_url, headers=None, json=None, 
timeout=DEFAULT_TIMEOUT): + if json is None: + json = {} + return http_request('DELETE', args, req_url, headers, json, timeout=timeout) def load_permissions_from_file(file_path): @@ -630,7 +649,7 @@ def apiheaders(args: argparse.Namespace) -> Dict: return result -def deindent(message: str, add_separator: bool = True) -> str: +def deindent(message: str) -> str: """ Deindent a quoted string. Scans message and finds the smallest number of whitespace characters in any line and removes that many from the start of every line. @@ -642,13 +661,141 @@ def deindent(message: str, add_separator: bool = True) -> str: indents = [len(x) for x in re.findall("^ *(?=[^ ])", message, re.MULTILINE) if len(x)] a = min(indents) message = re.sub(r"^ {," + str(a) + "}", "", message, flags=re.MULTILINE) - if add_separator: - # For help epilogs - cleanly separating extra help from options - line_width = min(150, shutil.get_terminal_size((80, 20)).columns) - message = "_"*line_width + "\n"*2 + message.strip() + "\n" + "_"*line_width return message.strip() +def api_call( + args: argparse.Namespace, + method: str, + path: str, + *, + json_body: dict[str, Any] | None = None, + query_args: dict[str, Any] | None = None, +) -> dict[str, Any] | list[dict[str, Any]] | None: + """Centralized API call: URL construction + HTTP dispatch + status check. + + Args: + args: argparse.Namespace with url, api_key, explain, raw, retry, curl. + method: HTTP method string ("GET", "POST", "PUT", "DELETE"). + path: API path (e.g., "/instances/", "/auth/apikeys/{id}/"). + json_body: Optional dict for request body (POST/PUT/DELETE). + query_args: Optional dict for URL query parameters. + + Returns: + Parsed JSON response (dict or list), or None for empty responses. + + Raises: + requests.exceptions.HTTPError: On non-2xx status codes. 
+ """ + url = apiurl(args, path, query_args) + dispatch = { + "GET": http_get, + "POST": http_post, + "PUT": http_put, + "DELETE": http_del, + } + http_fn = dispatch[method] + + if method == "GET": + r = http_fn(args, url, headers=headers, json=json_body) + else: + r = http_fn(args, url, headers=headers, json=json_body if json_body is not None else {}) + + r.raise_for_status() + + if r.content: + try: + return r.json() + except JSONDecodeError: + return {"_raw_text": r.text} + return None + + +def output_result( + args: argparse.Namespace, + data: list[dict[str, Any]] | dict[str, Any], + fields: list[tuple[str, str, str]] | None = None, +) -> list[dict[str, Any]] | dict[str, Any] | None: + """Unified output handler for command results. + + In raw mode: returns data for main() to serialize as JSON. + In table mode: calls display_table() if fields are provided. + In JSON mode: prints formatted JSON (when no fields defined). + + Args: + args: argparse.Namespace with raw flag. + data: The response data (dict, list, or None). + fields: Optional tuple of field definitions for display_table(). + + Returns: + data if in raw mode, None otherwise. + """ + if args.raw: + return data + if data is None: + return None + if fields: + rows = data if isinstance(data, list) else [data] + display_table(rows, fields) + else: + print(json.dumps(data, indent=1, sort_keys=True)) + return None + + +def error_output( + args: argparse.Namespace, + status_code: int, + message: str, + *, + detail: str | None = None, +) -> None: + """Output an error in the appropriate format for the current mode. + + In raw mode: prints JSON error object to stderr. + In non-raw mode: prints human-readable error to stderr. + + Args: + args: argparse.Namespace with raw flag. + status_code: HTTP status code or error code. + message: Error message string. + detail: Optional additional detail string. 
+ """ + if getattr(args, 'raw', False): + error = {"error": True, "status_code": status_code, "msg": message} + if detail: + error["detail"] = detail + print(json.dumps(error), file=sys.stderr) + else: + print(f"failed with error {status_code}: {message}", file=sys.stderr) + + +def require_id(args: argparse.Namespace, field: str = "id") -> int | str: + """Extract and validate an ID argument. + + Args: + args: argparse.Namespace containing the ID field. + field: Name of the attribute on args (default "id"). + + Returns: + The value of the requested field. + + Raises: + SystemExit: If the field is None or missing. + """ + val = getattr(args, field, None) + if val is None: + print(f"Error: {field} is required", file=sys.stderr) + raise SystemExit(1) + return val + + +# Field definition tuples: (key, display_name, format_string, converter_or_None, left_justify) +# key: API response dict key +# display_name: Column header in table output +# format_string: Python format spec (e.g., ">8", "<16", ">10.4f") +# converter_or_None: Lambda to transform value, or None for raw value +# left_justify: Boolean, True for left-aligned columns + # These are the fields that are displayed when a search is run displayable_fields = ( # ("bw_nvlink", "Bandwidth NVLink", "{}", None, True), @@ -873,8 +1020,8 @@ def deindent(message: str, add_separator: bool = True) -> str: # These fields are displayed when you do 'show maints' maintenance_fields = ( ("machine_id", "Machine ID", "{}", None, True), - ("start_time", "Start (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), - ("end_time", "End (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), + ("start_time", "Start (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), + ("end_time", "End (Date/Time)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), ("duration_hours", 
"Duration (Hrs)", "{}", None, True), ("maintenance_category", "Category", "{}", None, True), ) @@ -899,8 +1046,8 @@ def deindent(message: str, add_separator: bool = True) -> str: ("id", "Scheduled Job ID", "{}", None, True), ("instance_id", "Instance ID", "{}", None, True), ("api_endpoint", "API Endpoint", "{}", None, True), - ("start_time", "Start (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), - ("end_time", "End (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x).strftime('%Y-%m-%d/%H:%M'), True), + ("start_time", "Start (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), + ("end_time", "End (Date/Time in UTC)", "{}", lambda x: datetime.fromtimestamp(x, tz=timezone.utc).strftime('%Y-%m-%d/%H:%M'), True), ("day_of_the_week", "Day of the Week", "{}", None, True), ("hour_of_the_day", "Hour of the Day in UTC", "{}", None, True), ("min_of_the_hour", "Minute of the Hour", "{}", None, True), @@ -1105,13 +1252,16 @@ def parse_query(query_str: str, res: Dict = None, fields = {}, field_alias = {}, for field, op, _, value, _ in opts: value = value.strip(",[]") - v = res.setdefault(field, {}) op = op.strip() op_name = op_names.get(op) if field in field_alias: - res.pop(field) + old_field = field field = field_alias[field] + if old_field in res: + res[field] = res.pop(old_field) + + v = res.setdefault(field, {}) if (field == "driver_version") and ('.' 
in value): value = numeric_version(value) @@ -1163,23 +1313,35 @@ def parse_query(query_str: str, res: Dict = None, fields = {}, field_alias = {}, #print(res) return res -# ANSI escape codes for background/foreground colors -BG_DARK_GRAY = '\033[40m' # Dark gray background -BG_LIGHT_GRAY = '\033[48;5;240m' # Light gray background -FG_WHITE = '\033[97m' # Bright white text -BG_RESET = '\033[0m' # Reset all formatting -def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_width: bool = True) -> None: - """Basically takes a set of field names and rows containing the corresponding data and prints a nice tidy table - of it. +# ANSI color codes for table formatting +BG_DARK_GRAY = '\033[40m' # Dark gray background +BG_LIGHT_GRAY = '\033[48;5;240m' # Light gray background +FG_WHITE = '\033[97m' # Bright white text +BG_RESET = '\033[0m' # Reset all formatting - :param list rows: Each row is a dict with keys corresponding to the field names (first element) in the fields tuple. - :param Tuple fields: 5-tuple describing a field. First element is field name, second is human readable version, third is format string, fourth is a lambda function run on the data in that field, fifth is a bool determining text justification. True = left justify, False = right justify. Here is an example showing the tuples in action. +def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_width: bool = True) -> None: + """Display data as a formatted table with automatic column width management. - :rtype None: + Takes a set of field definitions and rows of data and prints a formatted table. + When auto_width is enabled, columns are grouped to fit within terminal width, + with alternating row colors for readability. - Example of 5-tuple: ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False) + Args: + rows: List of dicts with keys corresponding to field names in the fields tuple. 
+ fields: Tuple of 5-tuples defining each column: + - field_name: API response dict key + - display_name: Column header text + - format_string: Python format spec (e.g., "{:0.1f}") + - converter: Lambda to transform value, or None for raw value + - left_justify: Boolean, True for left-aligned columns + replace_spaces: If True, replace spaces with underscores in cell values. + auto_width: If True, automatically group columns to fit terminal width + with colored alternating rows. If False, print simple table. + + Example field tuple: + ("cpu_ram", "RAM", "{:0.1f}", lambda x: x / 1000, False) """ header = [name for _, name, _, _, _ in fields] out_rows = [header] @@ -1200,7 +1362,7 @@ def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_w idx = len(row) lengths[idx] = max(len(s), lengths[idx]) row.append(s) - + if auto_width: width = shutil.get_terminal_size((80, 20)).columns start_col_idxs = [0] @@ -1210,7 +1372,7 @@ def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_w if total_len > width: start_col_idxs.append(i) # index for the start of the next group total_len = l + 6 # l + 2 + the 4 from the initial length - + groups = {} for row in out_rows: grp_num = 0 @@ -1219,7 +1381,7 @@ def display_table(rows: list, fields: Tuple, replace_spaces: bool = True, auto_w end = start_col_idxs[i+1]-1 if i+1 < len(start_col_idxs) else len(lengths) groups.setdefault(grp_num, []).append(row[start:end]) grp_num += 1 - + for i, group in groups.items(): idx = start_col_idxs[i] group_lengths = lengths[idx:idx+len(group[0])] @@ -1289,7 +1451,7 @@ def parse_vast_url(url_str): try: instance_id = int(path) path = "/" - except: + except (ValueError, TypeError): pass valid_unix_path_regex = re.compile('^(/)?([^/\0]+(/)?)+$') @@ -1317,7 +1479,7 @@ def get_ssh_key(argstr): has around 200 or so "base64" characters and ends with some-user@some-where. "Generate public ssh key" would be a good search term if you don't know how to do this. 
- """, add_separator=False)) + """)) if not ssh_key.lower().startswith('ssh'): raise ValueError(deindent(""" @@ -1330,7 +1492,7 @@ def get_ssh_key(argstr): {} And welp, that just don't look right. - """.format(ssh_key), add_separator=False)) + """.format(ssh_key))) return ssh_key @@ -1338,14 +1500,15 @@ def get_ssh_key(argstr): @parser.command( argument("instance_id", help="id of instance to attach to", type=int), argument("ssh_key", help="ssh key to attach to instance", type=str), + description="Attach an SSH key to an instance for remote access", usage="vastai attach ssh instance_id ssh_key", - help="Attach an ssh key to an instance. This will allow you to connect to the instance with the ssh key", + help="Attach an SSH key to an instance for remote access", epilog=deindent(""" Attach an ssh key to an instance. This will allow you to connect to the instance with the ssh key. Examples: - vastai attach "ssh 12371 ssh-rsa AAAAB3NzaC1yc2EAAA..." - vastai attach "ssh 12371 ssh-rsa $(cat ~/.ssh/id_rsa)" + vastai attach ssh 12371 ssh-rsa AAAAB3NzaC1yc2EAAA... + vastai attach ssh 12371 ssh-rsa $(cat ~/.ssh/id_rsa) """), ) def attach__ssh(args): @@ -1354,12 +1517,20 @@ def attach__ssh(args): req_json = {"ssh_key": ssh_key} r = http_post(args, url, headers=headers, json=req_json) r.raise_for_status() - print(r.json()) + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + print(rj) @parser.command( argument("dst", help="instance_id:/path to target of copy operation", type=str), + description="Cancel an in-progress file copy operation", usage="vastai cancel copy DST", - help="Cancel a remote copy in progress, specified by DST id", + help="Cancel an in-progress file copy operation", epilog=deindent(""" Use this command to cancel any/all current remote copy operations copying to a specific named instance, given by DST. 
@@ -1388,21 +1559,24 @@ def cancel__copy(args: argparse.Namespace): req_json = { "client_id": "me", "dst_id": dst_id, } r = http_del(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json(); - if (rj["success"]): - print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") - else: - print(rj["msg"]); + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + if rj.get("success"): + print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj.get("msg", "Unknown error")); @parser.command( argument("dst", help="instance_id:/path to target of sync operation", type=str), + description="Cancel an in-progress file sync operation", usage="vastai cancel sync DST", - help="Cancel a remote copy in progress, specified by DST id", + help="Cancel an in-progress file sync operation", epilog=deindent(""" Use this command to cancel any/all current remote cloud sync operations copying to a specific named instance, given by DST. 
@@ -1431,15 +1605,17 @@ def cancel__sync(args: argparse.Namespace): req_json = { "client_id": "me", "dst_id": dst_id, } r = http_del(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json(); - if (rj["success"]): - print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") - else: - print(rj["msg"]); + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + if rj.get("success"): + print("Remote copy canceled - check instance status bar for progress updates (~30 seconds delayed).") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj.get("msg", "Unknown error")); def default_start_date(): return datetime.now(timezone.utc).strftime("%Y-%m-%d") @@ -1491,6 +1667,7 @@ def parse_hour_cron_style(value): argument("--end_date", type=str, default=default_end_date(), help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is 7 days from now. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. 
--hour 16", default=0),
+    description="Change the bid price for a spot/interruptible instance",
     usage="vastai change bid id [--price PRICE]",
     help="Change the bid price for a spot/interruptible instance",
     epilog=deindent("""
@@ -1504,8 +1681,6 @@ def change__bid(args: argparse.Namespace):
         :param argparse.Namespace args: should supply all the command-line options
         :rtype int:
     """
-    url = apiurl(args, "/instances/bid_price/{id}/".format(id=args.id))
-
     json_blob = {"client_id": "me", "price": args.price,}
     if (args.explain):
         print("request json: ")
@@ -1516,12 +1691,13 @@ def change__bid(args: argparse.Namespace):
         cli_command = "change bid"
         api_endpoint = "/api/v0/instances/bid_price/{id}/".format(id=args.id)
         json_blob["instance_id"] = args.id
-        add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id)
+        add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id)
         return
-
-    r = http_put(args, url, headers=headers, json=json_blob)
-    r.raise_for_status()
-    print("Per gpu bid price changed".format(r.json()))
+
+    result = api_call(args, "PUT", "/instances/bid_price/{id}/".format(id=args.id), json_body=json_blob)
+    if args.raw:
+        return result
+    print("Per gpu bid price changed. {}".format(result))



@@ -1530,8 +1706,9 @@ def change__bid(args: argparse.Namespace):
     argument("dest", help="id of volume offer volume is being copied to", type=int),
     argument("-s", "--size", help="Size of new volume contract, in GB. Must be greater than or equal to the source volume, and less than or equal to the destination offer.", type=float),
     argument("-d", "--disable_compression", action="store_true", help="Do not compress volume data before copying."),
+    description="Create a copy of an existing volume",
     usage="vastai copy volume [options]",
-    help="Clone an existing volume",
+    help="Create a copy of an existing volume",
     epilog=deindent("""
        Create a new volume with the given offer, by copying the existing volume.
Size defaults to the size of the existing volume, but can be increased if there is available space. @@ -1555,18 +1732,24 @@ def clone__volume(args: argparse.Namespace): print(json_blob) r = http_post(args, url, headers=headers,json=json_blob) r.raise_for_status() + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return if args.raw: - return r + return rj else: - print("Created. {}".format(r.json())) + print("Created. {}".format(rj)) @parser.command( argument("src", help="Source location for copy operation (supports multiple formats)", type=str), argument("dst", help="Target location for copy operation (supports multiple formats)", type=str), argument("-i", "--identity", help="Location of ssh private key", type=str), + description="Copy files/directories between instances or between local and instance", usage="vastai copy SRC DST", - help="Copy directories between instances and/or local", + help="Copy files/directories between instances or between local and instance", epilog=deindent(""" Copies a directory from a source location to a target location. 
Each of source and destination directories can be either local or remote, subject to appropriate read and write @@ -1632,43 +1815,43 @@ def copy(args: argparse.Namespace): url = apiurl(args, f"/commands/copy_direct/") r = http_put(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json() - #print(json.dumps(rj, indent=1, sort_keys=True)) - if (rj["success"]) and ((src_id is None or src_id == "local") or (dst_id is None or dst_id == "local")): - homedir = subprocess.getoutput("echo $HOME") - #print(f"homedir: {homedir}") - remote_port = None - identity = f"-i {args.identity}" if (args.identity is not None) else "" - if (src_id is None or src_id == "local"): - #result = subprocess.run(f"mkdir -p {src_path}", shell=True) - remote_port = rj["dst_port"] - remote_addr = rj["dst_addr"] - cmd = f"rsync -arz -v --progress --rsh=ssh -e 'ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no' {src_path} vastai_kaalia@{remote_addr}::{dst_id}/{dst_path}" - print(cmd) - result = subprocess.run(cmd, shell=True) - #result = subprocess.run(["sudo", "rsync" "-arz", "-v", "--progress", "-rsh=ssh", "-e 'sudo ssh -i {homedir}/.ssh/id_rsa -p {remote_port} -o StrictHostKeyChecking=no'", src_path, "vastai_kaalia@{remote_addr}::{dst_id}"], shell=True) - elif (dst_id is None or dst_id == "local"): - result = subprocess.run(f"mkdir -p {dst_path}", shell=True) - remote_port = rj["src_port"] - remote_addr = rj["src_addr"] - cmd = f"rsync -arz -v --progress --rsh=ssh -e 'ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no' vastai_kaalia@{remote_addr}::{src_id}/{src_path} {dst_path}" - print(cmd) - result = subprocess.run(cmd, shell=True) - #result = subprocess.run(["sudo", "rsync" "-arz", "-v", "--progress", "-rsh=ssh", "-e 'ssh -i {homedir}/.ssh/id_rsa -p {remote_port} -o StrictHostKeyChecking=no'", "vastai_kaalia@{remote_addr}::{src_id}", dst_path], shell=True) + except JSONDecodeError: + print("Error: API returned invalid 
JSON response", file=sys.stderr) + return + #print(json.dumps(rj, indent=1, sort_keys=True)) + if rj.get("success") and ((src_id is None or src_id == "local") or (dst_id is None or dst_id == "local")): + homedir = os.path.expanduser("~") + #print(f"homedir: {homedir}") + remote_port = None + identity = f"-i {args.identity}" if (args.identity is not None) else "" + if (src_id is None or src_id == "local"): + remote_port = rj.get("dst_port") + remote_addr = rj.get("dst_addr") + ssh_cmd = f"ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no".strip() + rsync_args = ["rsync", "-arz", "-v", "--progress", "-e", ssh_cmd, src_path, f"vastai_kaalia@{remote_addr}::{dst_id}/{dst_path}"] + print(" ".join(rsync_args)) + result = subprocess.run(rsync_args) + elif (dst_id is None or dst_id == "local"): + os.makedirs(dst_path, exist_ok=True) + remote_port = rj.get("src_port") + remote_addr = rj.get("src_addr") + ssh_cmd = f"ssh {identity} -p {remote_port} -o StrictHostKeyChecking=no".strip() + rsync_args = ["rsync", "-arz", "-v", "--progress", "-e", ssh_cmd, f"vastai_kaalia@{remote_addr}::{src_id}/{src_path}", dst_path] + print(" ".join(rsync_args)) + result = subprocess.run(rsync_args) + else: + if rj.get("success"): + print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") else: - if (rj["success"]): - print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") + msg = rj.get("msg", "Unknown error") + if msg == "src_path not supported VMs.": + print("copy between VM instances does not currently support subpaths (only full disk copy)") + elif msg == "dst_path not supported for VMs.": + print("copy between VM instances does not currently support subpaths (only full disk copy)") else: - if rj["msg"] == "src_path not supported VMs.": - print("copy between VM instances does not currently support subpaths (only full disk copy)") - elif rj["msg"] == "dst_path not 
supported for VMs.": - print("copy between VM instances does not currently support subpaths (only full disk copy)") - else: - print(rj["msg"]) - else: - print(r.text) - print("failed with error {r.status_code}".format(**locals())); + print(msg) ''' @@ -1709,20 +1892,21 @@ def vm__copy(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): + try: rj = r.json(); - if (rj["success"]): - print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") - else: - if rj["msg"] == "Invalid src_path.": - print("src instance is not a VM") - elif rj["msg"] == "Invalid dst_path.": - print("dst instance is not a VM") - else: - print(rj["msg"]); + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if rj.get("success"): + print("Remote to Remote copy initiated - check instance status bar for progress updates (~30 seconds delayed).") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + msg = rj.get("msg", "Unknown error") + if msg == "Invalid src_path.": + print("src instance is not a VM") + elif msg == "Invalid dst_path.": + print("dst instance is not a VM") + else: + print(msg); ''' @parser.command( @@ -1741,8 +1925,9 @@ def vm__copy(args: argparse.Namespace): argument("--end_date", type=str, help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is contract's end. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. 
--hour 16", default=0), + description="Copy files between instances and cloud storage (S3, GCS, Azure)", usage="vastai cloud copy --src SRC --dst DST --instance INSTANCE_ID -connection CONNECTION_ID --transfer TRANSFER_TYPE", - help="Copy files/folders to and from cloud providers", + help="Copy files between instances and cloud storage (S3, GCS, Azure)", epilog=deindent(""" Copies a directory from a source location to a target location. Each of source and destination directories can be either local or remote, subject to appropriate read and write @@ -1810,8 +1995,13 @@ def cloud__copy(args: argparse.Namespace): req_url = apiurl(args, "/instances/{id}/".format(id=args.instance) , {"owner": "me"} ) r = http_get(args, req_url) r.raise_for_status() - row = r.json()["instances"] - + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + row = rj.get("instances") + if args.transfer.lower() == "instance to cloud": if row: # Get the cost per TB of internet upload @@ -1838,12 +2028,8 @@ def cloud__copy(args: argparse.Namespace): r = http_post(args, url, headers=headers,json=req_json) r.raise_for_status() - if (r.status_code == 200): - print("Cloud Copy Started - check instance status bar for progress updates (~30 seconds delayed).") - print("When the operation is finished you should see 'Cloud Copy Operation Finished' in the instance status bar.") - else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print("Cloud Copy Started - check instance status bar for progress updates (~30 seconds delayed).") + print("When the operation is finished you should see 'Cloud Copy Operation Finished' in the instance status bar.") @parser.command( @@ -1853,10 +2039,11 @@ def cloud__copy(args: argparse.Namespace): argument("--docker_login_user",help="Username for container registry with repo", type=str), argument("--docker_login_pass",help="Password or token for container registry 
with repo", type=str), argument("--pause", help="Pause container's processes being executed by the CPU to take snapshot (true/false). Default will be true", type=str, default="true"), + description="Create a snapshot of a running container and push to registry", usage="vastai take snapshot INSTANCE_ID " "--repo REPO --docker_login_user USER --docker_login_pass PASS" "[--container_registry REGISTRY] [--pause true|false]", - help="Schedule a snapshot of a running container and push it to your repo in a container registry", + help="Create a snapshot of a running container and push to registry", epilog=deindent(""" Takes a snapshot of a running container instance and pushes snapshot to the specified repository in container registry. @@ -1901,16 +2088,15 @@ def take__snapshot(args: argparse.Namespace): # POST to the snapshot endpoint r = http_post(args, url, headers=headers, json=req_json) r.raise_for_status() - - if r.status_code == 200: + try: data = r.json() - if data.get("success"): - print(f"Snapshot request sent successfully. Please check your repo {repo} in container registry {container_registry} in 5-10 mins. It can take longer than 5-10 mins to push your snapshot image to your repo depending on the size of your image.") - else: - print(data.get("msg", "Unknown error with snapshot request")) + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if data.get("success"): + print(f"Snapshot request sent successfully. Please check your repo {repo} in container registry {container_registry} in 5-10 mins. 
It can take longer than 5-10 mins to push your snapshot image to your repo depending on the size of your image.") else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(data.get("msg", "Unknown error with snapshot request")) def validate_frequency_values(day_of_the_week, hour_of_the_day, frequency): @@ -1961,7 +2147,7 @@ def add_scheduled_job(args, req_json, cli_command, api_endpoint, request_method, "instance_id": instance_id } # Send a POST request - response = requests.post(schedule_job_url, headers=headers, json=request_body) + response = http_post(args, schedule_job_url, headers=headers, json=request_body) if args.explain: print("request json: ") @@ -1975,29 +2161,33 @@ def add_scheduled_job(args, req_json, cli_command, api_endpoint, request_method, elif response.status_code == 422: user_input = input("Existing scheduled job found. Do you want to update it (y|n)? ") if user_input.strip().lower() == "y": - scheduled_job_id = response.json()["scheduled_job_id"] + try: + resp_data = response.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + scheduled_job_id = resp_data.get("scheduled_job_id") + if not scheduled_job_id: + print("Error: API response missing required 'scheduled_job_id' field", file=sys.stderr) + return schedule_job_url = apiurl(args, f"/commands/schedule_job/{scheduled_job_id}/") - response = update_scheduled_job(cli_command, schedule_job_url, frequency, args.start_date, args.end_date, request_body) + response = update_scheduled_job(args, cli_command, schedule_job_url, frequency, args.start_date, args.end_date, request_body) else: print("Job update aborted by the user.") else: # print(r.text) print(f"add_scheduled_job insert: failed error: {response.status_code}. 
Response body: {response.text}") -def update_scheduled_job(cli_command, schedule_job_url, frequency, start_date, end_date, request_body): - response = requests.put(schedule_job_url, headers=headers, json=request_body) +def update_scheduled_job(args, cli_command, schedule_job_url, frequency, start_date, end_date, request_body): + response = http_put(args, schedule_job_url, headers=headers, json=request_body) - # Raise an exception for HTTP errors + # Raise an exception for HTTP errors response.raise_for_status() - if response.status_code == 200: - print(f"add_scheduled_job update: success - Scheduling {frequency} job to {cli_command} from {start_date} UTC to {end_date} UTC") - print(response.json()) - elif response.status_code == 401: - print(f"add_scheduled_job update: failed status_code: {response.status_code}. It could be because you aren't using a valid api_key.") - else: - # print(r.text) - print(f"add_scheduled_job update: failed status_code: {response.status_code}.") + print(f"add_scheduled_job update: success - Scheduling {frequency} job to {cli_command} from {start_date} UTC to {end_date} UTC") + try: print(response.json()) + except JSONDecodeError: + print(response.text) return response @@ -2006,8 +2196,9 @@ def update_scheduled_job(cli_command, schedule_job_url, frequency, start_date, e argument("--name", help="name of the api-key", type=str), argument("--permission_file", help="file path for json encoded permissions, see https://vast.ai/docs/cli/roles-and-permissions for more information", type=str), argument("--key_params", help="optional wildcard key params for advanced keys", type=str), + description="Create a new API key with custom permissions", usage="vastai create api-key --name NAME --permission_file PERMISSIONS", - help="Create a new api-key with restricted permissions. 
Can be sent to other users and teammates", + help="Create a new API key with custom permissions", epilog=deindent(""" In order to create api keys you must understand how permissions must be sent via json format. You can find more information about permissions here: https://vast.ai/docs/cli/roles-and-permissions @@ -2019,7 +2210,14 @@ def create__api_key(args): permissions = load_permissions_from_file(args.permission_file) r = http_post(args, url, headers=headers, json={"name": args.name, "permissions": permissions, "key_params": args.key_params}) r.raise_for_status() - print("api-key created {}".format(r.json())) + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return rj + print("api-key created {}".format(rj)) except FileNotFoundError: print("Error: Permission file '{}' not found.".format(args.permission_file)) except requests.exceptions.RequestException as e: @@ -2031,8 +2229,9 @@ def create__api_key(args): @parser.command( argument("subnet", help="local subnet for cluster, ex: '0.0.0.0/24'", type=str), argument("manager_id", help="Machine ID of manager node in cluster. 
Must exist already.", type=int), + description="[Beta] Create a new machine cluster", usage="vastai create cluster SUBNET MANAGER_ID", - help="Create Vast cluster", + help="[Beta] Create a new machine cluster", epilog=deindent(""" Create Vast Cluster by defining a local subnet and manager id.""") ) @@ -2052,25 +2251,31 @@ def create__cluster(args: argparse.Namespace): r = http_post(args, req_url, json=json_blob) r.raise_for_status() + try: + rj = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: - return r + return rj - print(r.json()["msg"]) + print(rj.get("msg", "Unknown error")) @parser.command( argument("name", help="Environment variable name", type=str), argument("value", help="Environment variable value", type=str), + description="Create a new account-level environment variable", usage="vastai create env-var ", - help="Create a new user environment variable", + help="Create a new account-level environment variable", ) def create__env_var(args): """Create a new environment variable for the current user.""" - url = apiurl(args, "/secrets/") data = {"key": args.name, "value": args.value} - r = http_post(args, url, headers=headers, json=data) - r.raise_for_status() + result = api_call(args, "POST", "/secrets/", json_body=data) - result = r.json() + if args.raw: + return result if result.get("success"): print(result.get("msg", "Environment variable created successfully.")) else: @@ -2079,8 +2284,9 @@ def create__env_var(args): @parser.command( argument("ssh_key", help="add your existing ssh public key to your account (from the .pub file). 
If no public key is provided, a new key pair will be generated.", type=str, nargs='?'), argument("-y", "--yes", help="automatically answer yes to prompts", action="store_true"), + description="Add an SSH public key to your account", usage="vastai create ssh-key [ssh_public_key] [-y]", - help="Create a new ssh-key", + help="Add an SSH public key to your account", epilog=deindent(""" You may use this command to add an existing public key, or create a new ssh key pair and add that public key, to your Vast account. @@ -2101,20 +2307,20 @@ def create__env_var(args): def create__ssh_key(args): ssh_key_content = args.ssh_key - + # If no SSH key provided, generate one if not ssh_key_content: ssh_key_content = generate_ssh_key(args.yes) else: print("Adding provided SSH public key to account...") - + # Send the SSH key to the API - url = apiurl(args, "/ssh/") - r = http_post(args, url, headers=headers, json={"ssh_key": ssh_key_content}) - r.raise_for_status() - + result = api_call(args, "POST", "/ssh/", json_body={"ssh_key": ssh_key_content}) + + if args.raw: + return result # Print json response - print("ssh-key created {}\nNote: You may need to add the new public key to any pre-existing instances".format(r.json())) + print("ssh-key created {}\nNote: You may need to add the new public key to any pre-existing instances".format(result)) def generate_ssh_key(auto_yes=False): @@ -2245,8 +2451,9 @@ def generate_ssh_key(auto_yes=False): argument("--cold_mult", help="[NOTE: this field isn't currently used at the workergroup level]cold/stopped instance capacity target as multiple of hot capacity target (default 2.0)", type=float), argument("--cold_workers", help="min number of workers to keep 'cold' for this workergroup", type=int), argument("--auto_instance", help=argparse.SUPPRESS, type=str, default="prod"), + description="Create an autoscaling worker group for serverless inference", usage="vastai workergroup create [OPTIONS]", - help="Create a new autoscale group", + 
help="Create an autoscaling worker group for serverless inference", epilog=deindent(""" Create a new autoscaling group to manage a pool of worker instances. @@ -2275,7 +2482,10 @@ def create__workergroup(args): r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("workergroup create {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("workergroup create {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2295,8 +2505,9 @@ def create__workergroup(args): argument("--endpoint_name", help="deployment endpoint name (allows multiple autoscale groups to share same deployment endpoint)", type=str), argument("--auto_instance", help=argparse.SUPPRESS, type=str, default="prod"), + description="Create a serverless inference endpoint", usage="vastai create endpoint [OPTIONS]", - help="Create a new endpoint group", + help="Create a serverless inference endpoint", epilog=deindent(""" Create a new endpoint group to manage many autoscaling groups @@ -2311,11 +2522,14 @@ def create__endpoint(args): if (args.explain): print("request json: ") print(json_blob) - r = requests.post(url, headers=headers,json=json_blob) + r = http_post(args, url, headers=headers, json=json_blob) r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("create endpoint {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("create endpoint {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2414,8 +2628,9 @@ def validate_portal_config(json_blob): argument("--mount-path", help="The path to the volume from within the new instance container. e.g. /root/volume", type=str), argument("--volume-label", help="(optional) A name to give the new volume. 
Only usable with --create-volume", type=str), + description="Create a new GPU instance from an offer", usage="vastai create instance ID [OPTIONS] [--args ...]", - help="Create a new instance", + help="Create a new GPU instance from an offer", epilog=deindent(""" Performs the same action as pressing the "RENT" button on the website at https://console.vast.ai/create/ Creates an instance from an offer ID (which is returned from "search offers"). Each offer ID can only be used to create one instance. @@ -2507,7 +2722,7 @@ def create__instance(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: - return r + return r.json() else: print("Started. {}".format(r.json())) @@ -2516,8 +2731,9 @@ def create__instance(args: argparse.Namespace): argument("--username", help="username to use for login", type=str), argument("--password", help="password to use for login", type=str), argument("--type", help="host/client", type=str), + description="Create a subaccount for delegated access", usage="vastai create subaccount --email EMAIL --username USERNAME --password PASSWORD --type TYPE", - help="Create a subaccount", + help="Create a subaccount for delegated access", epilog=deindent(""" Creates a new account that is considered a child of your current account as defined via the API key. 
@@ -2531,7 +2747,7 @@ def create__subaccount(args): """ # Default value for host_only, can adjust based on expected default behavior host_only = False - + # Only process the --account_type argument if it's provided if args.type: host_only = args.type.lower() == "host" @@ -2554,20 +2770,18 @@ def create__subaccount(args): url = apiurl(args, "/users/") r = http_post(args, url, headers=headers, json=json_blob) r.raise_for_status() - - if r.status_code == 200: - rj = r.json() - print(rj) - else: - print(r.text) - print(f"Failed with error {r.status_code}") + rj = r.json() + if args.raw: + return rj + print(rj) @parser.command( argument("--team_name", help="name of the team", type=str), + description="Create a new team", usage="vastai create-team --team_name TEAM_NAME", help="Create a new team", epilog=deindent(""" - Creates a new team under your account. + Creates a new team under your account. Unlike legacy teams, this command does NOT convert your personal account into a team. Each team is created as a separate account, and you can be a member of multiple teams. @@ -2578,13 +2792,9 @@ def create__subaccount(args): - Default roles (owner, manager, member) are automatically created. - You can invite others, assign roles, and manage resources within the team. - Optional: - You can transfer a portion of your existing personal credits to the team by using - the `--transfer_credit` flag. Example: - vastai create-team --team_name myteam --transfer_credit 25 - Notes: - You cannot create a team from within another team account. + - To transfer credits to a team, use `vastai transfer credit ` after team creation. 
For more details, see: https://vast.ai/docs/teams-quickstart @@ -2592,27 +2802,28 @@ def create__subaccount(args): ) def create__team(args): - url = apiurl(args, "/team/") - r = http_post(args, url, headers=headers, json={"team_name": args.team_name}) - r.raise_for_status() - print(r.json()) + result = api_call(args, "POST", "/team/", json_body={"team_name": args.team_name}) + if args.raw: + return result + print(result) @parser.command( argument("--name", help="name of the role", type=str), argument("--permissions", help="file path for json encoded permissions, look in the docs for more information", type=str), + description="Create a custom role with specific permissions", usage="vastai create team-role --name NAME --permissions PERMISSIONS", - help="Add a new role to your team", + help="Create a custom role with specific permissions", epilog=deindent(""" Creating a new team role involves understanding how permissions must be sent via json format. You can find more information about permissions here: https://vast.ai/docs/cli/roles-and-permissions """) ) def create__team_role(args): - url = apiurl(args, "/team/roles/") permissions = load_permissions_from_file(args.permissions) - r = http_post(args, url, headers=headers, json={"name": args.name, "permissions": permissions}) - r.raise_for_status() - print(r.json()) + result = api_call(args, "POST", "/team/roles/", json_body={"name": args.name, "permissions": permissions}) + if args.raw: + return result + print(result) def get_template_arguments(): return [ @@ -2640,8 +2851,9 @@ def get_template_arguments(): @parser.command( *get_template_arguments(), + description="Create a reusable instance configuration template", usage="vastai create template", - help="Create a new template", + help="Create a reusable instance configuration template", epilog=deindent(""" Create a template that can be used to create instances with @@ -2668,7 +2880,7 @@ def create__template(args): default_search_query = {} if not args.no_default: 
default_search_query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}} - + extra_filters = parse_query(args.search_params, default_search_query, offers_fields, offers_alias, offers_mult) template = { "name" : args.name, @@ -2701,10 +2913,12 @@ def create__template(args): r.raise_for_status() try: rj = r.json() - if rj["success"]: - print(f"New Template: {rj['template']}") + if args.raw: + return rj + if rj.get("success"): + print(f"New Template: {rj.get('template', '')}") else: - print(rj['msg']) + print(rj.get('msg', 'Unknown error')) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") @@ -2714,15 +2928,16 @@ def create__template(args): argument("-s", "--size", help="size in GB of volume. Default %(default)s GB.", default=15, type=float), argument("-n", "--name", help="Optional name of volume.", type=str), + description="Create a new persistent storage volume", usage="vastai create volume ID [options]", - help="Create a new volume", + help="Create a new persistent storage volume", epilog=deindent(""" Creates a volume from an offer ID (which is returned from "search volumes"). Each offer ID can be used to create multiple volumes, provided the size of all volumes does not exceed the size of the offer. """) ) def create__volume(args: argparse.Namespace): - + json_blob ={ "size": int(args.size), "id": int(args.id) @@ -2738,7 +2953,7 @@ def create__volume(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: - return r + return r.json() else: print("Created. {}".format(r.json())) @@ -2748,8 +2963,9 @@ def create__volume(args: argparse.Namespace): argument("-s", "--size", help="size in GB of network volume. 
Default %(default)s GB.", default=15, type=float), argument("-n", "--name", help="Optional name of network volume.", type=str), + description="[Host] [Beta] Create a new network-attached storage volume", usage="vastai create network volume ID [options]", - help="Create a new network volume", + help="[Host] [Beta] Create a new network-attached storage volume", epilog=deindent(""" Creates a network volume from an offer ID (which is returned from "search network volumes"). Each offer ID can be used to create multiple volumes, provided the size of all volumes does not exceed the size of the offer. @@ -2772,15 +2988,16 @@ def create__network_volume(args: argparse.Namespace): r = http_put(args, url, headers=headers,json=json_blob) r.raise_for_status() if args.raw: - return r + return r.json() else: print("Created. {}".format(r.json())) @parser.command( argument("cluster_id", help="ID of cluster to create overlay on top of", type=int), argument("name", help="overlay network name"), + description="[Beta] Create a virtual overlay network on a cluster", usage="vastai create overlay CLUSTER_ID OVERLAY_NAME", - help="Creates overlay network on top of a physical cluster", + help="[Beta] Create a virtual overlay network on a cluster", epilog=deindent(""" Creates an overlay network to allow local networking between instances on a physical cluster""") ) @@ -2798,49 +3015,53 @@ def create__overlay(args: argparse.Namespace): r.raise_for_status() if args.raw: - return r + return r.json() - print(r.json()["msg"]) + print(r.json().get("msg", "Unknown error")) @parser.command( argument("id", help="id of apikey to remove", type=int), + description="Delete an API key", usage="vastai delete api-key ID", - help="Remove an api-key", + help="Delete an API key", ) def delete__api_key(args): - url = apiurl(args, "/auth/apikeys/{id}/".format(id=args.id)) - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", 
f"/auth/apikeys/{args.id}/") + if args.raw: + return result + print(result) @parser.command( argument("id", help="id ssh key to delete", type=int), + description="Remove an SSH key from your account", usage="vastai delete ssh-key ID", - help="Remove an ssh-key", + help="Remove an SSH key from your account", ) def delete__ssh_key(args): - url = apiurl(args, "/ssh/{id}/".format(id=args.id)) - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", f"/ssh/{args.id}/") + if args.raw: + return result + print(result) @parser.command( argument("id", help="id of scheduled job to remove", type=int), + description="Delete a scheduled job", usage="vastai delete scheduled-job ID", help="Delete a scheduled job", ) def delete__scheduled_job(args): - url = apiurl(args, "/commands/schedule_job/{id}/".format(id=args.id)) - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", f"/commands/schedule_job/{args.id}/") + if args.raw: + return result + print(result) @parser.command( argument("cluster_id", help="ID of cluster to delete", type=int), + description="[Beta] Delete a machine cluster", usage="vastai delete cluster CLUSTER_ID", - help="Delete Cluster", + help="[Beta] Delete a machine cluster", epilog=deindent(""" Delete Vast Cluster""") ) @@ -2852,28 +3073,27 @@ def delete__cluster(args: argparse.Namespace): if args.explain: print("request json:", json_blob) - req_url = apiurl(args, "/cluster/") - r = http_del(args, req_url, json=json_blob) - r.raise_for_status() + result = api_call(args, "DELETE", "/cluster/", json_body=json_blob) if args.raw: - return r + return result - print(r.json()["msg"]) + print(result.get("msg", "Unknown error")) @parser.command( argument("id", help="id of group to delete", type=int), + description="Delete an autoscaling worker group", usage="vastai delete workergroup ID ", - help="Delete a workergroup group", + help="Delete an 
autoscaling worker group", epilog=deindent(""" Note that deleting a workergroup doesn't automatically destroy all the instances that are associated with your workergroup. Example: vastai delete workergroup 4242 """), ) def delete__workergroup(args): - id = args.id - url = apiurl(args, f"/autojobs/{id}/" ) + workergroup_id = args.id + url = apiurl(args, f"/autojobs/{workergroup_id}/") json_blob = {"client_id": "me", "autojob_id": args.id} if (args.explain): print("request json: ") @@ -2882,7 +3102,10 @@ def delete__workergroup(args): r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("workergroup delete {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("workergroup delete {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2893,15 +3116,16 @@ def delete__workergroup(args): @parser.command( argument("id", help="id of endpoint group to delete", type=int), + description="Delete a serverless inference endpoint", usage="vastai delete endpoint ID ", - help="Delete an endpoint group", + help="Delete a serverless inference endpoint", epilog=deindent(""" Example: vastai delete endpoint 4242 """), ) def delete__endpoint(args): - id = args.id - url = apiurl(args, f"/endptjobs/{id}/" ) + endpoint_id = args.id + url = apiurl(args, f"/endptjobs/{endpoint_id}/") json_blob = {"client_id": "me", "endptjob_id": args.id} if (args.explain): print("request json: ") @@ -2910,7 +3134,10 @@ def delete__endpoint(args): r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print("delete endpoint {}".format(r.json())) + rj = r.json() + if args.raw: + return rj + print("delete endpoint {}".format(rj)) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -2921,6 +3148,7 @@ def delete__endpoint(args): @parser.command( argument("name", help="Environment variable name to delete", type=str), + 
description="Delete a user environment variable", usage="vastai delete env-var ", help="Delete a user environment variable", ) @@ -2931,7 +3159,13 @@ def delete__env_var(args): r = http_del(args, url, headers=headers, json=data) r.raise_for_status() - result = r.json() + try: + result = r.json() + except JSONDecodeError: + print("Error: API returned invalid JSON response", file=sys.stderr) + return + if args.raw: + return result if result.get("success"): print(result.get("msg", "Environment variable deleted successfully.")) else: @@ -2939,8 +3173,9 @@ def delete__env_var(args): @parser.command( argument("overlay_identifier", help="ID (int) or name (str) of overlay to delete", nargs="?"), + description="[Beta] Delete an overlay network and its instances", usage="vastai delete overlay OVERLAY_IDENTIFIER", - help="Deletes overlay and removes all of its associated instances" + help="[Beta] Delete an overlay network and its instances" ) def delete__overlay(args: argparse.Namespace): identifier = args.overlay_identifier @@ -2957,20 +3192,19 @@ def delete__overlay(args: argparse.Namespace): if args.explain: print("request json:", json_blob) - req_url = apiurl(args, "/overlay/") - r = http_del(args, req_url, json=json_blob) - r.raise_for_status() + result = api_call(args, "DELETE", "/overlay/", json_body=json_blob) if args.raw: - return r + return result - print(r.json()["msg"]) + print(result.get("msg", "Unknown error")) @parser.command( argument("--template-id", help="Template ID of Template to Delete", type=int), argument("--hash-id", help="Hash ID of Template to Delete", type=str), + description="Delete a template", usage="vastai delete template [--template-id | --hash-id ]", - help="Delete a Template", + help="Delete a template", epilog=deindent(""" Note: Deleting a template only removes the user's replationship to a template. 
It does not get destroyed Example: vastai delete template --template-id 12345 @@ -2979,7 +3213,7 @@ def delete__overlay(args: argparse.Namespace): ) def delete__template(args): url = apiurl(args, f"/template/" ) - + if args.hash_id: json_blob = { "hash_id": args.hash_id } elif args.template_id: @@ -2987,18 +3221,20 @@ def delete__template(args): else: print('ERROR: Must Specify either Template ID or Hash ID to delete a template') return - + if (args.explain): print("request json: ") print(json_blob) print(args) print(url) r = http_del(args, url, headers=headers,json=json_blob) - print(r) # r.raise_for_status() if 'application/json' in r.headers.get('Content-Type', ''): try: - print(r.json()['msg']) + rj = r.json() + if args.raw: + return rj + print(rj.get('msg', 'Unknown error')) except requests.exceptions.JSONDecodeError: print("The response is not valid JSON.") print(r) @@ -3010,20 +3246,19 @@ def delete__template(args): @parser.command( argument("id", help="id of volume contract", type=int), + description="Delete a persistent storage volume", usage="vastai delete volume ID", - help="Delete a volume", + help="Delete a persistent storage volume", epilog=deindent(""" Deletes volume with the given ID. All instances using the volume must be destroyed before the volume can be deleted. """) ) def delete__volume(args: argparse.Namespace): - url = apiurl(args, "/volumes/", query_args={"id": args.id}) - r = http_del(args, url, headers=headers) - r.raise_for_status() + result = api_call(args, "DELETE", "/volumes/", query_args={"id": args.id}) if args.raw: - return r + return result else: - print("Deleted. {}".format(r.json())) + print("Deleted. 
{}".format(result)) def destroy_instance(id,args): @@ -3031,25 +3266,32 @@ def destroy_instance(id,args): r = http_del(args, url, headers=headers,json={}) r.raise_for_status() if args.raw: - return r - elif (r.status_code == 200): - rj = r.json(); - if (rj["success"]): - print("destroying instance {id}.".format(**(locals()))); - else: - print(rj["msg"]); + return r.json() + rj = r.json(); + if rj.get("success"): + print("destroying instance {id}.".format(**(locals()))); else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj.get("msg", "Unknown error")); @parser.command( argument("id", help="id of instance to delete", type=int), + description="Destroy an instance (irreversible, deletes data)", usage="vastai destroy instance id [-h] [--api-key API_KEY] [--raw]", help="Destroy an instance (irreversible, deletes data)", epilog=deindent(""" - Perfoms the same action as pressing the "DESTROY" button on the website at https://console.vast.ai/instances/ - Example: vastai destroy instance 4242 + Performs the same action as pressing the "DESTROY" button on the website at https://console.vast.ai/instances/ + + WARNING: This action is IMMEDIATE and IRREVERSIBLE. All data on the instance will be permanently + deleted unless you have saved it to a persistent volume or external storage. 
+ + Examples: + vastai destroy instance 12345 # Destroy instance with ID 12345 + + Before destroying: + - Save any important data using 'vastai copy' or by mounting a persistent volume + - Check instance ID carefully with 'vastai show instances' + - Consider using 'vastai stop instance' if you want to pause without data loss """), ) def destroy__instance(args): @@ -3060,8 +3302,9 @@ def destroy__instance(args): destroy_instance(args.id,args) @parser.command( - argument("ids", help="ids of instance to destroy", type=int, nargs='+'), - usage="vastai destroy instances [--raw] ", + argument("ids", help="ids of instances to destroy", type=int, nargs='+'), + description="Destroy a list of instances (irreversible, deletes data)", + usage="vastai destroy instances IDS [OPTIONS]", help="Destroy a list of instances (irreversible, deletes data)", ) def destroy__instances(args): @@ -3071,20 +3314,22 @@ def destroy__instances(args): destroy_instance(id, args) @parser.command( + description="Delete your team and remove all members", usage="vastai destroy team", - help="Destroy your team", + help="Delete your team and remove all members", ) def destroy__team(args): - url = apiurl(args, "/team/") - r = http_del(args, url, headers=headers) - r.raise_for_status() - print(r.json()) + result = api_call(args, "DELETE", "/team/") + if args.raw: + return result + print(result) @parser.command( argument("instance_id", help="id of the instance", type=int), argument("ssh_key_id", help="id of the key to detach to the instance", type=str), + description="Remove an SSH key from an instance", usage="vastai detach instance_id ssh_key_id", - help="Detach an ssh key from an instance", + help="Remove an SSH key from an instance", epilog=deindent(""" Example: vastai detach 99999 12345 """) @@ -3093,7 +3338,13 @@ def detach__ssh(args): url = apiurl(args, "/instances/{id}/ssh/{ssh_key_id}/".format(id=args.instance_id, ssh_key_id=args.ssh_key_id)) r = http_del(args, url, headers=headers) 
r.raise_for_status() - print(r.json()) + try: + rj = r.json() + except JSONDecodeError: + rj = {"response": r.text} + if args.raw: + return rj + print(rj) @parser.command( argument("id", help="id of instance to execute on", type=int), @@ -3103,8 +3354,9 @@ def detach__ssh(args): argument("--end_date", type=str, default=default_end_date(), help="End date/time in format 'YYYY-MM-DD HH:MM:SS PM' (UTC). Default is 7 days from now. (optional)"), argument("--day", type=parse_day_cron_style, help="Day of week you want scheduled job to run on (0-6, where 0=Sunday) or \"*\". Default will be 0. For ex. --day 0", default=0), argument("--hour", type=parse_hour_cron_style, help="Hour of day you want scheduled job to run on (0-23) or \"*\" (UTC). Default will be 0. For ex. --hour 16", default=0), + description="Execute a command on a running instance", usage="vastai execute id COMMAND", - help="Execute a (constrained) remote command on a machine", + help="Execute a command on a running instance", epilog=deindent(""" Examples: vastai execute 99999 'ls -l -o -r' @@ -3141,22 +3393,21 @@ def execute(args): add_scheduled_job(args, json_blob, cli_command, api_endpoint, "PUT", instance_id=args.id) return - if (r.status_code == 200): - rj = r.json() - if (rj["success"]): - for i in range(0,30): - time.sleep(0.3) - url = rj["result_url"] - r = requests.get(url) - if (r.status_code == 200): - filtered_text = r.text.replace(rj["writeable_path"], ''); - print(filtered_text) - break - else: - print(rj); + rj = r.json() + if rj.get("success"): + url = rj.get("result_url") + if not url: + print("Error: API response missing required 'result_url' field", file=sys.stderr) + return + for i in range(0,30): + time.sleep(0.3) + r = http_get(args, url) + if (r.status_code == 200): + filtered_text = r.text.replace(rj.get("writeable_path", ''), ''); + print(filtered_text) + break else: - print(r.text); - print("failed with error {r.status_code}".format(**locals())); + print(rj); @@ -3164,8 +3415,9 @@ 
def execute(args): argument("id", help="id of endpoint group to fetch logs from", type=int), argument("--level", help="log detail level (0 to 3)", type=int, default=1), argument("--tail", help="", type=int, default=None), + description="Get logs for a serverless endpoint", usage="vastai get endpt-logs ID [--api-key API_KEY]", - help="Fetch logs for a specific serverless endpoint group", + help="Get logs for a serverless endpoint", epilog=deindent(""" Example: vastai get endpt-logs 382 """), @@ -3175,7 +3427,7 @@ def get__endpt_logs(args): if args.url == server_url_default: args.url = None url = (args.url or "https://run.vast.ai") + "/get_endpoint_logs/" - json_blob = {"id": args.id, "api_key": args.api_key} + json_blob = {"id": args.id} if args.tail: json_blob["tail"] = args.tail if (args.explain): print(f"{url} with request json: ") @@ -3185,29 +3437,27 @@ def get__endpt_logs(args): r.raise_for_status() levels = {0 : "info0", 1: "info1", 2: "trace", 3: "debug"} - if (r.status_code == 200): - rj = None - try: - rj = r.json() - except Exception as e: - print(str(e)) - print(r.text) - if args.raw: - # sort_keys - return rj or r.text - else: - dbg_lvl = levels[args.level] - if rj and dbg_lvl: print(rj[dbg_lvl]) - #print(json.dumps(rj, indent=1, sort_keys=True)) - else: + rj = None + try: + rj = r.json() + except Exception as e: + print(str(e)) print(r.text) + if args.raw: + # sort_keys + return rj or r.text + else: + dbg_lvl = levels[args.level] + if rj and dbg_lvl: print(rj[dbg_lvl]) + #print(json.dumps(rj, indent=1, sort_keys=True)) @parser.command( argument("id", help="id of endpoint group to fetch logs from", type=int), argument("--level", help="log detail level (0 to 3)", type=int, default=1), argument("--tail", help="", type=int, default=None), + description="Get logs for an autoscaling worker group", usage="vastai get wrkgrp-logs ID [--api-key API_KEY]", - help="Fetch logs for a specific serverless worker group group", + help="Get logs for an autoscaling worker 
group", epilog=deindent(""" Example: vastai get endpt-logs 382 """), @@ -3217,7 +3467,7 @@ def get__wrkgrp_logs(args): if args.url == server_url_default: args.url = None url = (args.url or "https://run.vast.ai") + "/get_autogroup_logs/" - json_blob = {"id": args.id, "api_key": args.api_key} + json_blob = {"id": args.id} if args.tail: json_blob["tail"] = args.tail if (args.explain): print(f"{url} with request json: ") @@ -3227,45 +3477,46 @@ def get__wrkgrp_logs(args): r.raise_for_status() levels = {0 : "info0", 1: "info1", 2: "trace", 3: "debug"} - if (r.status_code == 200): - rj = None - try: - rj = r.json() - except Exception as e: - print(str(e)) - print(r.text) - if args.raw: - # sort_keys - return rj or r.text - else: - dbg_lvl = levels[args.level] - if rj and dbg_lvl: print(rj[dbg_lvl]) - #print(json.dumps(rj, indent=1, sort_keys=True)) - else: + rj = None + try: + rj = r.json() + except Exception as e: + print(str(e)) print(r.text) + if args.raw: + # sort_keys + return rj or r.text + else: + dbg_lvl = levels[args.level] + if rj and dbg_lvl: print(rj[dbg_lvl]) + #print(json.dumps(rj, indent=1, sort_keys=True)) @parser.command( argument("--email", help="email of user to be invited", type=str), argument("--role", help="role of user to be invited", type=str), + description="Invite a user to join your team", usage="vastai invite member --email EMAIL --role ROLE", - help="Invite a team member", + help="Invite a user to join your team", ) def invite__member(args): url = apiurl(args, "/team/invite/", query_args={"email": args.email, "role": args.role}) r = http_post(args, url, headers=headers) r.raise_for_status() - if (r.status_code == 200): - print(f"successfully invited {args.email} to your current team") - else: - print(r.text); - print(f"failed with error {r.status_code}") + try: + rj = r.json() + except JSONDecodeError: + rj = {"success": True, "email": args.email} + if args.raw: + return rj + print(f"successfully invited {args.email} to your current team") 
@parser.command( argument("cluster_id", help="ID of cluster to add machine to", type=int), argument("machine_ids", help="machine id(s) to join cluster", type=int, nargs="+"), + description="[Beta] Add a machine to an existing cluster", usage="vastai join cluster CLUSTER_ID MACHINE_IDS", - help="Join Machine to Cluster", + help="[Beta] Add a machine to an existing cluster", epilog=deindent(""" Join's Machine to Vast Cluster """) @@ -3284,16 +3535,17 @@ def join__cluster(args: argparse.Namespace): r.raise_for_status() if args.raw: - return r + return r.json() - print(r.json()["msg"]) + print(r.json().get("msg", "Unknown error")) @parser.command( argument("name", help="Overlay network name to join instance to.", type=str), argument("instance_id", help="Instance ID to add to overlay.", type=int), + description="[Beta] Connect an instance to an overlay network", usage="vastai join overlay OVERLAY_NAME INSTANCE_ID", - help="Adds instance to an overlay network", + help="[Beta] Connect an instance to an overlay network", epilog=deindent(""" Adds an instance to a compatible overlay network.""") ) @@ -3311,15 +3563,16 @@ def join__overlay(args: argparse.Namespace): r.raise_for_status() if args.raw: - return r + return r.json() - print(r.json()["msg"]) + print(r.json().get("msg", "Unknown error")) @parser.command( argument("id", help="id of instance to label", type=int), argument("label", help="label to set", type=str), + description="Assign a string label to an instance", usage="vastai label instance