Patches applied to PrismML's llama.cpp fork to compile with CUDA 10.2 on NVIDIA Jetson Nano (SM 5.3 / Maxwell) using GCC 8 and aarch64 NEON.
Stats: 27 files modified, 504 insertions, 501 deletions.
1. C++17 to C++14 Downgrade
2. CUDA 10.2 API Stubs
3. SM 5.3 Maxwell Specifics
4. ARM NEON GCC 8 Workarounds
5. Linker Fixes
6. Critical Correctness Fix (binbcast.cu)
7. Build System
Scope: All platforms targeting CUDA 10.2 (its nvcc supports C++14 at most).
Files: common.cuh, binbcast.cu, concat.cu, convert.cuh, cumsum.cu,
fattn-common.cuh, ggml-cuda.cu, mma.cuh, mmf.cuh, mmid.cu, mmq.cuh,
mmvf.cu, mmvq.cu, norm.cu, quantize.cu, rope.cu, softmax.cu,
topk-moe.cu, tri.cu
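CUDA 10.2's nvcc does not support if constexpr, so every instance was replaced
with a plain if. When the condition depends only on template parameters, the
compiler removes the dead branch, so there is no runtime cost. Unlike if
constexpr, however, a plain if still instantiates both branches, so each branch
must compile for every template argument it is used with.
```cpp
// Before (C++17)
if constexpr (std::is_same_v<T, float>) {
    ...
}

// After (C++14)
if (std::is_same<T, float>::value) {
    ...
}
```
The _v variable template shortcuts (such as std::is_same_v) are C++17 additions.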
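```cpp
// Before
std::is_same_v<T, U>

// After
std::is_same<T, U>::value
```
The C++17 inline variable template built on a fold expression was replaced with
the C++14 recursive template struct pattern. (Variable templates themselves are
C++14; inline variables and fold expressions are C++17.)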
```cpp
// Before (C++17)
template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);

// After (C++14): recursive struct, queried as is_any_impl<T, Ts...>::value
template <typename T, typename... Ts>
struct is_any_impl : std::false_type {};  // empty pack: no match

template <typename T, typename First, typename... Rest>
struct is_any_impl<T, First, Rest...>
    : std::conditional<std::is_same<T, First>::value,
                       std::true_type,
                       is_any_impl<T, Rest...>>::type {};
```
The inline constexpr bool ggml_cuda_dependent_false_v<T> = false helper used
in static_assert was removed entirely. It was replaced with the sizeof(T) == 0
idiom, or with (void)0 where the assert was non-critical.
```cpp
// Before
static_assert(ggml_cuda_dependent_false_v<T>, "unsupported type");

// After: sizeof(T) == 0 depends on T, so it only fires when instantiated
static_assert(sizeof(T) == 0, "unsupported type");
```
C++17 structured bindings were decomposed into explicit pair access.
```cpp
// Before (C++17)
for (auto & [key, graph] : map) {
    ...
}

// After (C++14)
for (auto & kv : map) {
    auto & key = kv.first;
    auto & graph = kv.second;
    ...
}
```
C++17 fold expressions such as (expr || ...) were replaced with direct calls or
simplified logic.
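The recursion pattern behind is_any_impl also covers plain (b || ...) folds. Below is a compilable C++14 sketch of both: a recursive constexpr function, one possible stand-in for a logical-or fold, and the is_any_impl struct from the excerpt above with a convenience wrapper. The names any_of and is_any_v are hypothetical and do not come from the patch; variable templates such as is_any_v are already valid C++14, only inline variables are C++17.

```cpp
#include <type_traits>

// Recursive C++14 replacement for a C++17 fold such as (b || ...).
// any_of is a hypothetical helper name used only for illustration.
constexpr bool any_of() { return false; }  // empty pack: same result as an empty || fold

template <typename... Rest>
constexpr bool any_of(bool first, Rest... rest) {
    // Unlike a fold over expressions, every argument has already been
    // evaluated before the call; only the recursion itself stops early.
    return first || any_of(rest...);
}

// Recursive struct pattern, same shape as is_any_impl in the patch.
template <typename T, typename... Ts>
struct is_any_impl : std::false_type {};

template <typename T, typename First, typename... Rest>
struct is_any_impl<T, First, Rest...>
    : std::conditional<std::is_same<T, First>::value,
                       std::true_type,
                       is_any_impl<T, Rest...>>::type {};

// Hypothetical convenience wrapper (a plain C++14 variable template).
template <typename T, typename... Ts>
constexpr bool is_any_v = is_any_impl<T, Ts...>::value;

// Compile-time checks: the build fails if either pattern misbehaves.
static_assert(is_any_v<float, int, float>, "float is in the list");
static_assert(!is_any_v<double, int, float>, "double is not in the list");
static_assert(!is_any_v<int>, "an empty list never matches");
static_assert(any_of(false, true, false), "one operand is true");
static_assert(!any_of(), "an empty pack is false");
```

Call sites that do not want a wrapper can spell is_any_impl<T, Ts...>::value directly; the static_asserts double as a check that both patterns agree with the folds they replace for these inputs.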
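Standard attributes were removed or replaced with GCC equivalents. Only
[[maybe_unused]] is new in C++17; [[noreturn]] has been standard since C++11,
so the C++14 target alone does not require replacing it.
```cpp
// Before
[[maybe_unused]] int x = 0;
[[noreturn]] void die();

// After
int x = 0;                             // attribute removed; __attribute__((unused))
                                       // would keep unused-variable warnings quiet
__attribute__((noreturn)) void die();  // GCC equivalent
```
C++17 inline static data members were downgraded.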
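```cpp
// Before
static inline const int value = 42;

// After
static const int value = 42;
// or
static constexpr int value = 42;
```
In C++14 these in-class initializers do not define the member, so if it is
odr-used (for example, bound to a const reference) one out-of-class definition
is still required.
CUDA 10.2 rejected constexpr on these __device__ functions, so the qualifier
was removed.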
```cpp
// Before
static constexpr __device__ float warp_reduce(float x);

// After
static __device__ float warp_reduce(float x);
```
Scope: CUDA 10.2 only (these APIs were added in CUDA 11+).
Files: vendors/cuda.h, stubs/cooperative_groups/reduce.h, ggml-cuda.cu,
softmax.cu, fa-stub.cu (new)
cuda_bf16.h does not exist in CUDA 10.2. A full stub header provides the
necessary types and conversion functions.
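A stub whose float to bf16 conversion keeps only the upper 16 bits truncates toward zero, whereas CUDA 11's __float2bfloat16 rounds to nearest even, so the two can differ in the last bit. The host-only C++ sketch below implements both conversions for comparing stub output against the rounding mode; all function names are hypothetical and are not part of the stub header.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 binary32 value.
// All function names here are hypothetical, for illustration only.

// Truncating conversion: what a plain `bits >> 16` stub does.
inline std::uint16_t f32_to_bf16_trunc(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<std::uint16_t>(bits >> 16);
}

// Round-to-nearest-even conversion: the mode CUDA 11's __float2bfloat16 uses.
inline std::uint16_t f32_to_bf16_rne(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    if ((bits & 0x7f800000u) == 0x7f800000u && (bits & 0x007fffffu) != 0) {
        // NaN: force a quiet NaN, since dropping or rounding the low bits
        // could otherwise yield infinity
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
    }
    const std::uint32_t lsb = (bits >> 16) & 1u;  // lowest kept bit (ties go to even)
    bits += 0x7fffu + lsb;                        // rounding bias, carries into kept bits
    return static_cast<std::uint16_t>(bits >> 16);
}

// bf16 -> float is exact: shift back into the upper half.
inline float bf16_to_f32(std::uint16_t h) {
    const std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Example values; each input is exactly representable as a float.
inline void bf16_rounding_examples() {
    assert(f32_to_bf16_trunc(1.005859375f) == 0x3F80);  // 1 + 2^-8 + 2^-9: truncated
    assert(f32_to_bf16_rne(1.005859375f)   == 0x3F81);  // above half-way: rounds up
    assert(f32_to_bf16_rne(1.00390625f)    == 0x3F80);  // 1 + 2^-8, tie: stays even
    assert(f32_to_bf16_rne(1.01171875f)    == 0x3F82);  // 1 + 2^-7 + 2^-8, tie: odd rounds up
    assert(bf16_to_f32(0x3F80) == 1.0f);
}
```

For values that already fit in bf16 (such as 1.0f or 3.0f) both conversions agree; differences appear only when the dropped low 16 bits are at or above half a bf16 ULP.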
```cpp
// vendors/cuda.h (excerpt)
#include <cstring>  // memcpy (also usable in device code)

struct __nv_bfloat16 {
    unsigned short __x;

    __host__ __device__ __nv_bfloat16() : __x(0) {}

    // float -> bf16 conversion: keeps the upper 16 bits, i.e. truncates.
    // CUDA 11's __float2bfloat16 rounds to nearest even, so results can
    // differ in the last bit.
    __host__ __device__ explicit __nv_bfloat16(float f) {
        unsigned int bits;
        memcpy(&bits, &f, sizeof(bits));
        __x = (unsigned short)(bits >> 16);
    }

    // bf16 -> float conversion (exact)
    __host__ __device__ operator float() const {
        unsigned int bits = ((unsigned int)__x) << 16;
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }
};

struct __nv_bfloat162 {
    __nv_bfloat16 x, y;
};

__host__ __device__ inline __nv_bfloat16 __float2bfloat16(float f) {
    return __nv_bfloat16(f);
}

__host__ __device__ inline float __bfloat162float(__nv_bfloat16 bf) {
    return (float)bf;
}
```
cooperative_groups/reduce.h does not exist in CUDA 10.2. An empty stub file
is placed at stubs/cooperative_groups/reduce.h, and the stubs directory is
added to the include path.
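CUDA_R_16BF, the bf16 cudaDataType_t value introduced in CUDA 11, gets a
fallback definition so that code referring to it still compiles:
```cpp
#ifndef CUDA_R_16BF
#define CUDA_R_16BF 2  // placeholder only: 2 is also CUDA_R_16F in CUDA 10.2
#endif
```
The extra flag passed to cudaHostRegister() upstream does not exist in CUDA
10.2, so it was removed from those calls (the flag is only an optimization hint
and is not required for correctness).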
In CUDA 10.2, cudaStreamWaitEvent() requires an explicit flags argument; CUDA
11+ made it optional with a default of 0.
```cpp
// Before (CUDA 11+)
cudaStreamWaitEvent(stream, event);

// After (CUDA 10.2)
cudaStreamWaitEvent(stream, event, 0);
```
Scope: SM 5.3 (Jetson Nano / Maxwell) only. These patches work around missing
hardware features (no tensor cores, limited warp primitives) and CUDA 10.2's
constexpr limitations in device code.
Files: common.cuh, mmq.cuh, mmvq.cu, CMakeLists.txt,
fa-stub.cu (new)
Upstream, ggml_cuda_get_physical_warp_size() and ggml_cuda_get_max_cpy_bytes()
are functions whose result depends on the target device. On Maxwell the values
are fixed (32 and 8), and the function-based approach ran into CUDA 10.2's
constexpr limitations, so both were replaced with macros.
```cpp
// Before
int ws = ggml_cuda_get_physical_warp_size();
int cb = ggml_cuda_get_max_cpy_bytes();

// After (call sites stay the same and expand to the constants)
#define ggml_cuda_get_physical_warp_size() 32
#define ggml_cuda_get_max_cpy_bytes() 8
```
Maxwell has no tensor cores, so the dynamic MMQ parameter functions are
replaced with fixed values known to work on SM 5.3.
```cpp
#define mmq_get_nwarps_device()      8
#define get_mmq_y_device()           64
#define get_mmq_x_max_device()       64
#define mmq_get_granularity_device() 8
#define get_iter_k()                 MMQ_ITER_K
```
Flash attention kernels use features not available on SM 5.3 (tensor cores,
advanced warp primitives), so all fattn*.cu files are excluded from the build.
A new fa-stub.cu provides stub implementations that return false or abort,
ensuring the rest of the codebase still links without flash attention support.
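The stub approach can be sketched in plain C++: a capability query that always reports "unsupported", so dispatch code falls back to the regular attention path, and an entry point that aborts if it is reached anyway. All names and signatures below (tensor_stub, fa_supported, fa_run, run_attention) are hypothetical; the fork's actual fa-stub.cu declarations are not shown in this document.

```cpp
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for the tensor type passed to attention code.
struct tensor_stub { int ne0; };

// Capability query: always false in this build, so callers never pick flash attention.
constexpr bool fa_supported(int /*device*/, const tensor_stub * /*dst*/) {
    return false;
}

// Entry point kept only so the symbol links; reaching it indicates a bug.
inline void fa_run(const tensor_stub * /*dst*/) {
    std::fprintf(stderr, "flash attention is not available in this build\n");
    std::abort();
}

// Hypothetical dispatcher: returns true if flash attention was used.
inline bool run_attention(int device, const tensor_stub * dst) {
    if (fa_supported(device, dst)) {
        fa_run(dst);
        return true;
    }
    // ... the regular (non-flash) attention path would run here ...
    return false;
}
```

The "return false" half is what keeps the build working in practice; the abort only guards against a caller that skips the capability check.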
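__launch_bounds__ takes compile-time constant arguments, and with CUDA 10.2 on
this target the function calls used there failed to compile, so they were
replaced with literal values. The first argument (maximum threads per block)
must be at least the block size the kernel is launched with; otherwise the
launch fails.
```cpp
// Before
__launch_bounds__(get_max_threads(), get_min_blocks())

// After
__launch_bounds__(256, 1)
// or
__launch_bounds__(128, 1)
```
Scope: aarch64 with GCC 8 (Jetson Nano default toolchain).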
Files: ggml-cpu-impl.h
GCC 8 on aarch64 has a known bug where multi-load NEON intrinsics
(vld1q_s8_x4, vld1q_u8_x4, etc.) return incorrect types, causing
compilation failures.
```c
// Before (uses compiler intrinsics directly)
#define ggml_int8x16x4_t int8x16x4_t
#define ggml_vld1q_s8_x4 vld1q_s8_x4

// After (custom struct + inline function doing four individual loads)
#include <arm_neon.h>

typedef struct {
    int8x16_t val[4];
} ggml_int8x16x4_t;

static inline ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
    ggml_int8x16x4_t res;
    res.val[0] = vld1q_s8(ptr + 0);   // bytes  0..15
    res.val[1] = vld1q_s8(ptr + 16);  // bytes 16..31
    res.val[2] = vld1q_s8(ptr + 32);  // bytes 32..47
    res.val[3] = vld1q_s8(ptr + 48);  // bytes 48..63
    return res;
}
```
Scope: GCC 8 (requires explicit filesystem library linkage).
Files: ggml/src/CMakeLists.txt
GCC 8 does not automatically link libstdc++fs for std::filesystem. Added
explicit linkage.
```cmake
target_link_libraries(ggml PRIVATE stdc++fs)
target_link_libraries(ggml-base PRIVATE stdc++fs)
```
Scope: ALL PLATFORMS. This correctness bug was introduced during the C++14
porting process and affects the result of every binary operation.
File: binbcast.cu
The C++17 fold expression in the binary broadcast kernel was incorrectly
replaced with (void)0 during porting. This silently broke all binary
operations (add, mul, sub, div). Models would load and appear to run but
produce garbage output.
```cpp
// Original (C++17 fold expression)
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));

// BROKEN (incorrect C++14 port)
(void)0; // fold expression removed for CUDA 10.2 compat

// FIXED (correct C++14 equivalent)
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
```
Symptoms when broken: the model loads successfully, the tokenizer works, and
inference runs without errors, but all output is nonsensical. This is extremely
difficult to diagnose because there are no crashes or warnings.
Note: This fix should be considered for upstream contribution, since the broken fold expression replacement produces wrong results regardless of CUDA version.
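The FIXED line applies bin_op once, with a single src1 operand. If the src1s pack in the fork's kernel can carry several pointers (for example when binary ops are fused), one call does not reproduce the fold. A common C++14 stand-in for the comma fold (..., (result = op(result, x))) is pack expansion inside a braced initializer list, whose elements are evaluated strictly left to right. This host-side sketch uses hypothetical names (apply_all, add_op, sub_op) and is not the fork's kernel code.

```cpp
#include <cassert>

// C++14 stand-in for the C++17 comma fold
//   result = (..., (result = op(result, (float)ptrs[i])));
// Operands are applied in pack order, left to right.
template <typename Op, typename... Ptrs>
float apply_all(Op op, float result, int i, const Ptrs *... ptrs) {
    // Elements of a braced initializer list are evaluated in order, so each
    // op() call sees the result of the previous one. The leading 0 keeps the
    // array non-empty when the pack is empty.
    using expand = int[];
    (void)expand{0, (result = op(result, static_cast<float>(ptrs[i])), 0)...};
    return result;
}

inline float add_op(float a, float b) { return a + b; }
inline float sub_op(float a, float b) { return a - b; }

inline void apply_all_examples() {
    const float a[] = {1.0f, 2.0f, 3.0f};
    const float b[] = {10.0f, 20.0f, 30.0f};
    assert(apply_all(sub_op, 100.0f, 1, a, b) == 78.0f);  // (100 - 2) - 20
    assert(apply_all(add_op, 5.0f, 2, a) == 8.0f);        // single operand: 5 + 3
    assert(apply_all(add_op, 5.0f, 0) == 5.0f);           // empty pack: unchanged
}
```

With a one-element pack the expansion performs exactly one op() call, which is the case a single bin_op call covers; the subtraction example shows that operand order is preserved when there are more.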
Scope: CUDA 10.2 + SM 5.3 build configuration.
Files: ggml/src/ggml-cuda/CMakeLists.txt
```cmake
set_target_properties(ggml-cuda PROPERTIES CUDA_STANDARD 14)

# Exclude all flash attention template instantiations
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX "fattn.*\\.cu")
# Corresponding glob patterns commented out in source lists

# Initializes CUDA_ARCHITECTURES for targets created after this point
set(CMAKE_CUDA_ARCHITECTURES 53)
```

| Scope | Categories | Risk |
|---|---|---|
| All platforms | 6 (binbcast.cu correctness fix) | Critical - silent wrong results |
| Any CUDA 10.2 target | 1 (C++17->14), 2 (API stubs), 7 (build) | Build failures without patches |
| SM 5.3 Maxwell only | 3 (hardware-specific) | Build/runtime failures on Maxwell |
| aarch64 + GCC 8 | 4 (NEON), 5 (linker) | Build failures on Jetson Nano toolchain |