From e66f80026cd3b03c8926587ab77e192547cf9a61 Mon Sep 17 00:00:00 2001 From: gnuduncan Date: Sun, 19 Apr 2026 09:17:25 +0200 Subject: [PATCH 1/2] Fix U32/U16/I8/I16 weight loading for quantized models _tensor_to_mlx declared U32 and other integer dtypes in DTYPE_UNPACK but never branched on them in the if/elsif chain. Packed 4-bit quantized weights (stored as uint32 in mlx-community safetensors) fell through to the F32 fallback and were decoded as garbage floats, causing `[dequantize] The matrix should be given as a uint32` on the first QuantizedEmbedding forward. Reproduces on mlx-community/Llama-3.2-1B-Instruct-4bit and presumably every 4-bit model. --- lib/mlx_lm/weight_utils.rb | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index cffa5c4..b752476 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -71,6 +71,18 @@ def _tensor_to_mlx(info, mx) elsif dtype_str == "U8" values = data.unpack("C*") mx.array(values, dtype: mx.uint8).reshape(shape) + elsif dtype_str == "U16" + values = data.unpack("S<*") + mx.array(values, dtype: mx.uint16).reshape(shape) + elsif dtype_str == "U32" + values = data.unpack("L<*") + mx.array(values, dtype: mx.uint32).reshape(shape) + elsif dtype_str == "I8" + values = data.unpack("c*") + mx.array(values, dtype: mx.int8).reshape(shape) + elsif dtype_str == "I16" + values = data.unpack("s<*") + mx.array(values, dtype: mx.int16).reshape(shape) else # Fallback: try F32 values = data.unpack("e*") From 40650bbab40725f4e834183904171f35236a595f Mon Sep 17 00:00:00 2001 From: gnuduncan Date: Sun, 19 Apr 2026 09:22:32 +0200 Subject: [PATCH 2/2] Use DTYPE_UNPACK for table-driven dtype dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The if/elsif chain in _tensor_to_mlx duplicated the DTYPE_UNPACK constant declared at the top of the file — which is how the U16/U32/I8/I16 branches went missing from the 
chain in the first place. Table-driven lookup keeps the mapping in one place. F16 and BF16 stay as explicit branches because they take a different code path (uint16 stage + .view cast). Unknown-dtype F32 fallback is preserved to match prior behavior. Uses __send__ instead of send because MLX::Core defines a `send` method (takes 2..4 args) that would shadow Object#send. --- lib/mlx_lm/weight_utils.rb | 49 +++++++++++--------------------------- 1 file changed, 14 insertions(+), 35 deletions(-) diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index b752476..abdcf36 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -51,42 +51,21 @@ def _tensor_to_mlx(info, mx) dtype_str = info["dtype"] data = info["data"] - # For F32/float32, unpack as little-endian floats - if dtype_str == "F32" || dtype_str == "float32" - values = data.unpack("e*") - mx.array(values).reshape(shape) - elsif dtype_str == "F16" - # 16-bit float: unpack as uint16, create array as float32, then view as float16 - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape) - elsif dtype_str == "BF16" - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape) - elsif dtype_str == "I32" || dtype_str == "int32" - values = data.unpack("l<*") - mx.array(values, dtype: mx.int32).reshape(shape) - elsif dtype_str == "I64" - values = data.unpack("q<*") - mx.array(values, dtype: mx.int64).reshape(shape) - elsif dtype_str == "U8" - values = data.unpack("C*") - mx.array(values, dtype: mx.uint8).reshape(shape) - elsif dtype_str == "U16" - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).reshape(shape) - elsif dtype_str == "U32" - values = data.unpack("L<*") - mx.array(values, dtype: mx.uint32).reshape(shape) - elsif dtype_str == "I8" - values = data.unpack("c*") - mx.array(values, dtype: mx.int8).reshape(shape) - elsif dtype_str == "I16" - values = data.unpack("s<*") - 
mx.array(values, dtype: mx.int16).reshape(shape) + dtype_str = "F32" if dtype_str == "float32" + dtype_str = "I32" if dtype_str == "int32" + + # F16/BF16 lack a direct unpack path; stage through uint16 + .view. + if dtype_str == "F16" || dtype_str == "BF16" + view_dtype = dtype_str == "F16" ? mx.float16 : mx.bfloat16 + return mx.array(data.unpack("S<*"), dtype: mx.uint16).view(view_dtype).reshape(shape) + end + + format_str, dtype_sym = DTYPE_UNPACK[dtype_str] + if format_str + mx.array(data.unpack(format_str), dtype: mx.__send__(dtype_sym)).reshape(shape) else - # Fallback: try F32 - values = data.unpack("e*") - mx.array(values).reshape(shape) + # Unknown dtype — interpret raw bytes as little-endian F32. + mx.array(data.unpack("e*")).reshape(shape) end end