Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions lib/mlx_lm/weight_utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,21 @@ def _tensor_to_mlx(info, mx)
dtype_str = info["dtype"]
data = info["data"]

# For F32/float32, unpack as little-endian floats
if dtype_str == "F32" || dtype_str == "float32"
values = data.unpack("e*")
mx.array(values).reshape(shape)
elsif dtype_str == "F16"
# 16-bit float: unpack as uint16, create array as float32, then view as float16
values = data.unpack("S<*")
mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape)
elsif dtype_str == "BF16"
values = data.unpack("S<*")
mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape)
elsif dtype_str == "I32" || dtype_str == "int32"
values = data.unpack("l<*")
mx.array(values, dtype: mx.int32).reshape(shape)
elsif dtype_str == "I64"
values = data.unpack("q<*")
mx.array(values, dtype: mx.int64).reshape(shape)
elsif dtype_str == "U8"
values = data.unpack("C*")
mx.array(values, dtype: mx.uint8).reshape(shape)
dtype_str = "F32" if dtype_str == "float32"
dtype_str = "I32" if dtype_str == "int32"

# F16/BF16 lack a direct unpack path; stage through uint16 + .view.
if dtype_str == "F16" || dtype_str == "BF16"
view_dtype = dtype_str == "F16" ? mx.float16 : mx.bfloat16
return mx.array(data.unpack("S<*"), dtype: mx.uint16).view(view_dtype).reshape(shape)
end

format_str, dtype_sym = DTYPE_UNPACK[dtype_str]
if format_str
mx.array(data.unpack(format_str), dtype: mx.__send__(dtype_sym)).reshape(shape)
else
# Fallback: try F32
values = data.unpack("e*")
mx.array(values).reshape(shape)
# Unknown dtype — interpret raw bytes as little-endian F32.
mx.array(data.unpack("e*")).reshape(shape)
end
end

Expand Down