diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index cffa5c4..abdcf36 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -51,30 +51,21 @@ def _tensor_to_mlx(info, mx) dtype_str = info["dtype"] data = info["data"] - # For F32/float32, unpack as little-endian floats - if dtype_str == "F32" || dtype_str == "float32" - values = data.unpack("e*") - mx.array(values).reshape(shape) - elsif dtype_str == "F16" - # 16-bit float: unpack as uint16, create array as float32, then view as float16 - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape) - elsif dtype_str == "BF16" - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape) - elsif dtype_str == "I32" || dtype_str == "int32" - values = data.unpack("l<*") - mx.array(values, dtype: mx.int32).reshape(shape) - elsif dtype_str == "I64" - values = data.unpack("q<*") - mx.array(values, dtype: mx.int64).reshape(shape) - elsif dtype_str == "U8" - values = data.unpack("C*") - mx.array(values, dtype: mx.uint8).reshape(shape) + dtype_str = "F32" if dtype_str == "float32" + dtype_str = "I32" if dtype_str == "int32" + + # F16/BF16 lack a direct unpack path; stage through uint16 + .view. + if dtype_str == "F16" || dtype_str == "BF16" + view_dtype = dtype_str == "F16" ? mx.float16 : mx.bfloat16 + return mx.array(data.unpack("S<*"), dtype: mx.uint16).view(view_dtype).reshape(shape) + end + + format_str, dtype_sym = DTYPE_UNPACK[dtype_str] + if format_str + mx.array(data.unpack(format_str), dtype: mx.__send__(dtype_sym)).reshape(shape) else - # Fallback: try F32 - values = data.unpack("e*") - mx.array(values).reshape(shape) + # Unknown dtype — interpret raw bytes as little-endian F32. + mx.array(data.unpack("e*")).reshape(shape) end end