From e66f80026cd3b03c8926587ab77e192547cf9a61 Mon Sep 17 00:00:00 2001 From: gnuduncan Date: Sun, 19 Apr 2026 09:17:25 +0200 Subject: [PATCH 1/3] Fix U32/U16/I8/I16 weight loading for quantized models _tensor_to_mlx declared U32 and other integer dtypes in DTYPE_UNPACK but never branched on them in the if/elsif chain. Packed 4-bit quantized weights (stored as uint32 in mlx-community safetensors) fell through to the F32 fallback and were decoded as garbage floats, causing `[dequantize] The matrix should be given as a uint32` on the first QuantizedEmbedding forward. Reproduces on mlx-community/Llama-3.2-1B-Instruct-4bit and presumably every 4-bit model. --- lib/mlx_lm/weight_utils.rb | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index cffa5c4..b752476 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -71,6 +71,18 @@ def _tensor_to_mlx(info, mx) elsif dtype_str == "U8" values = data.unpack("C*") mx.array(values, dtype: mx.uint8).reshape(shape) + elsif dtype_str == "U16" + values = data.unpack("S<*") + mx.array(values, dtype: mx.uint16).reshape(shape) + elsif dtype_str == "U32" + values = data.unpack("L<*") + mx.array(values, dtype: mx.uint32).reshape(shape) + elsif dtype_str == "I8" + values = data.unpack("c*") + mx.array(values, dtype: mx.int8).reshape(shape) + elsif dtype_str == "I16" + values = data.unpack("s<*") + mx.array(values, dtype: mx.int16).reshape(shape) else # Fallback: try F32 values = data.unpack("e*") From 40650bbab40725f4e834183904171f35236a595f Mon Sep 17 00:00:00 2001 From: gnuduncan Date: Sun, 19 Apr 2026 09:22:32 +0200 Subject: [PATCH 2/3] Use DTYPE_UNPACK for table-driven dtype dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The if/elsif chain in _tensor_to_mlx duplicated the DTYPE_UNPACK constant declared at the top of the file — which is how the U16/U32/ I8/I16 branches went missing from the chain in the first place. Table- driven lookup keeps the mapping in one place. F16 and BF16 stay as explicit branches because they take a different code path (uint16 stage + .view cast). Unknown-dtype F32 fallback is preserved to match prior behavior. Uses __send__ instead of send because MLX::Core defines a `send` method (takes 2..4 args) that would shadow Object#send. --- lib/mlx_lm/weight_utils.rb | 49 +++++++++++--------------------------- 1 file changed, 14 insertions(+), 35 deletions(-) diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index b752476..abdcf36 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -51,42 +51,21 @@ def _tensor_to_mlx(info, mx) dtype_str = info["dtype"] data = info["data"] - # For F32/float32, unpack as little-endian floats - if dtype_str == "F32" || dtype_str == "float32" - values = data.unpack("e*") - mx.array(values).reshape(shape) - elsif dtype_str == "F16" - # 16-bit float: unpack as uint16, create array as float32, then view as float16 - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape) - elsif dtype_str == "BF16" - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape) - elsif dtype_str == "I32" || dtype_str == "int32" - values = data.unpack("l<*") - mx.array(values, dtype: mx.int32).reshape(shape) - elsif dtype_str == "I64" - values = data.unpack("q<*") - mx.array(values, dtype: mx.int64).reshape(shape) - elsif dtype_str == "U8" - values = data.unpack("C*") - mx.array(values, dtype: mx.uint8).reshape(shape) - elsif dtype_str == "U16" - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).reshape(shape) - elsif dtype_str == "U32" - values = data.unpack("L<*") - mx.array(values, dtype: mx.uint32).reshape(shape) - elsif dtype_str == "I8" - values = data.unpack("c*") - mx.array(values, dtype: mx.int8).reshape(shape) - elsif dtype_str == "I16" - values = data.unpack("s<*") - mx.array(values, dtype: mx.int16).reshape(shape) + dtype_str = "F32" if dtype_str == "float32" + dtype_str = "I32" if dtype_str == "int32" + + # F16/BF16 lack a direct unpack path; stage through uint16 + .view. + if dtype_str == "F16" || dtype_str == "BF16" + view_dtype = dtype_str == "F16" ? mx.float16 : mx.bfloat16 + return mx.array(data.unpack("S<*"), dtype: mx.uint16).view(view_dtype).reshape(shape) + end + + format_str, dtype_sym = DTYPE_UNPACK[dtype_str] + if format_str + mx.array(data.unpack(format_str), dtype: mx.__send__(dtype_sym)).reshape(shape) else - # Fallback: try F32 - values = data.unpack("e*") - mx.array(values).reshape(shape) + # Unknown dtype — interpret raw bytes as little-endian F32. + mx.array(data.unpack("e*")).reshape(shape) end end From 238ff17d2452263a43795ff9fd8fa595d9015d8e Mon Sep 17 00:00:00 2001 From: gnuduncan Date: Sun, 19 Apr 2026 11:26:02 +0200 Subject: [PATCH 3/3] Add Hub.snapshot_download for HF Hub repo IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves the cli.md caveat that `--model ` "does not work yet — current runtime loading expects a local model directory". MlxLm::Hub.snapshot_download fetches a model snapshot from huggingface.co using pure Ruby + stdlib net/http. Cache layout matches huggingface_hub Python's (models----/{refs,snapshots}/...) so caches are mutually reusable — a model downloaded via Python mlx_lm is used as-is, and vice-versa. LoadUtils.load now accepts either a local path or an HF repo id like "mlx-community/Llama-3.2-1B-Instruct-4bit", plus an optional revision: parameter (default "main"). File.directory? short-circuits the local path; otherwise the repo is resolved via the hub. Respects env vars the Python client uses: HF_HUB_CACHE, HF_HOME, HF_ENDPOINT, HF_TOKEN. Handles HTTP redirects (HF CDN) and resolves relative Location headers against the current URL. Scoped Authorization to huggingface.co hosts (redirected CDNs reject bearer tokens). Tested against mlx-community/Llama-3.2-1B-Instruct-4bit: bundle exec exe/mlx_lm generate \ --model mlx-community/Llama-3.2-1B-Instruct-4bit \ --prompt 'The capital of France is' produces `Paris.` on both a warm cache and a fresh cache (verified with allow_patterns to a throwaway cache_dir). --- lib/mlx_lm.rb | 1 + lib/mlx_lm/hub.rb | 131 +++++++++++++++++++++++++++++++++++++++ lib/mlx_lm/load_utils.rb | 22 +++++-- 3 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 lib/mlx_lm/hub.rb diff --git a/lib/mlx_lm.rb b/lib/mlx_lm.rb index fc06890..10bfc97 100644 --- a/lib/mlx_lm.rb +++ b/lib/mlx_lm.rb @@ -127,6 +127,7 @@ require_relative "mlx_lm/quantize" require_relative "mlx_lm/quant/awq" require_relative "mlx_lm/quant/gptq" +require_relative "mlx_lm/hub" require_relative "mlx_lm/load_utils" require_relative "mlx_lm/evaluate" require_relative "mlx_lm/tuner/callbacks" diff --git a/lib/mlx_lm/hub.rb b/lib/mlx_lm/hub.rb new file mode 100644 index 0000000..470c888 --- /dev/null +++ b/lib/mlx_lm/hub.rb @@ -0,0 +1,131 @@ +require "fileutils" +require "json" +require "net/http" +require "pathname" +require "uri" + +module MlxLm + # Minimal pure-Ruby equivalent of huggingface_hub.snapshot_download. + # Cache layout matches the Python client so caches are mutually reusable. + module Hub + DEFAULT_ENDPOINT = "https://huggingface.co".freeze + MAX_REDIRECTS = 5 + DEFAULT_TIMEOUT = 300 + + module_function + + # Download a model snapshot from Hugging Face Hub. + # + # @param repo_id [String] e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit" + # @param revision [String] branch, tag, or commit SHA (default "main") + # @param allow_patterns [Array, String, nil] glob patterns to include + # @param token [String, nil] defaults to ENV["HF_TOKEN"] + # @param cache_dir [String, Pathname, nil] defaults to HF_HUB_CACHE / HF_HOME/hub / ~/.cache/huggingface/hub + # @param endpoint [String, nil] defaults to HF_ENDPOINT or huggingface.co + # @return [Pathname] absolute path to the snapshot directory + def snapshot_download(repo_id, revision: "main", allow_patterns: nil, token: nil, cache_dir: nil, endpoint: nil) + endpoint ||= ENV["HF_ENDPOINT"] || DEFAULT_ENDPOINT + token ||= ENV["HF_TOKEN"] + cache_dir = resolve_cache_dir(cache_dir) + patterns = normalize_patterns(allow_patterns) + + info = fetch_model_info(endpoint, repo_id, revision, token) + sha = info.fetch("sha") + siblings = info.fetch("siblings").map { |s| s.fetch("rfilename") } + + repo_folder = cache_dir.join("models--#{repo_id.gsub("/", "--")}") + snapshot_dir = repo_folder.join("snapshots", sha) + FileUtils.mkdir_p(snapshot_dir) + FileUtils.mkdir_p(repo_folder.join("refs")) + File.write(repo_folder.join("refs", revision), sha) + + siblings.each do |rel| + next unless pattern_match?(rel, patterns) + target = snapshot_dir.join(rel) + next if target.file? && target.size > 0 + + FileUtils.mkdir_p(target.dirname) + url = "#{endpoint}/#{repo_id}/resolve/#{revision}/#{rel}" + download_file(url, target.to_s, token) + end + + snapshot_dir + end + + def resolve_cache_dir(explicit) + return Pathname.new(explicit) if explicit + if (v = ENV["HF_HUB_CACHE"]) && !v.empty? + Pathname.new(v) + elsif (v = ENV["HF_HOME"]) && !v.empty? + Pathname.new(v).join("hub") + else + Pathname.new(Dir.home).join(".cache", "huggingface", "hub") + end + end + + def normalize_patterns(p) + return nil if p.nil? + p.is_a?(::Array) ? p : [p] + end + + def pattern_match?(filename, patterns) + return true if patterns.nil? || patterns.empty? + patterns.any? do |pat| + File.fnmatch(pat, filename, File::FNM_PATHNAME) || + File.fnmatch(pat, File.basename(filename)) + end + end + + def fetch_model_info(endpoint, repo_id, revision, token) + url = "#{endpoint}/api/models/#{repo_id}/revision/#{revision}" + body = http_get_body(url, token) + JSON.parse(body) + end + + def http_get_body(url, token, limit = MAX_REDIRECTS) + raise "Too many redirects fetching #{url}" if limit <= 0 + uri = URI.parse(url) + Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http| + req = Net::HTTP::Get.new(uri.request_uri) + req["Authorization"] = "Bearer #{token}" if token && uri.host.end_with?("huggingface.co") + req["User-Agent"] = user_agent + resp = http.request(req) + case resp + when Net::HTTPSuccess then resp.body + when Net::HTTPRedirection then http_get_body(URI.join(url, resp["location"]).to_s, token, limit - 1) + else raise "HTTP #{resp.code} #{resp.message} fetching #{url}: #{resp.body}" + end + end + end + + def download_file(url, output_path, token = nil, limit = MAX_REDIRECTS) + raise "Too many redirects fetching #{url}" if limit <= 0 + uri = URI.parse(url) + Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https", read_timeout: DEFAULT_TIMEOUT) do |http| + req = Net::HTTP::Get.new(uri.request_uri) + req["Authorization"] = "Bearer #{token}" if token && uri.host.end_with?("huggingface.co") + req["User-Agent"] = user_agent + http.request(req) do |resp| + case resp + when Net::HTTPSuccess + tmp = "#{output_path}.download.#{Process.pid}" + File.open(tmp, "wb") { |f| resp.read_body { |chunk| f.write(chunk) } } + FileUtils.mv(tmp, output_path) + when Net::HTTPRedirection + return download_file(URI.join(url, resp["location"]).to_s, output_path, token, limit - 1) + else + raise "HTTP #{resp.code} #{resp.message} fetching #{url}" + end + end + end + end + + def user_agent + v = defined?(MlxLm::VERSION) ? MlxLm::VERSION : "unknown" + "mlx-ruby-lm/#{v}" + end + + private_class_method :resolve_cache_dir, :normalize_patterns, :pattern_match?, + :fetch_model_info, :download_file, :user_agent + end +end diff --git a/lib/mlx_lm/load_utils.rb b/lib/mlx_lm/load_utils.rb index 02bbfa2..e5e8d73 100644 --- a/lib/mlx_lm/load_utils.rb +++ b/lib/mlx_lm/load_utils.rb @@ -4,17 +4,29 @@ module MlxLm module LoadUtils module_function - # Load a model and tokenizer from a local directory. + # Load a model and tokenizer from a local directory or Hugging Face repo id. # - # @param model_path [String] Path to the model directory + # @param model_path [String] Local directory path, or HF repo id like "org/model" # @param tokenizer_config [Hash] Additional tokenizer config overrides + # @param revision [String] HF branch/tag/sha; ignored for local paths # @return [Array(nn::Module, TokenizerWrapper)] The loaded model and tokenizer - def load(model_path, tokenizer_config: nil) - model, _config = load_model(model_path) - tokenizer = load_tokenizer(model_path) + def load(model_path, tokenizer_config: nil, revision: "main") + local_path = resolve_path(model_path, revision: revision) + model, _config = load_model(local_path) + tokenizer = load_tokenizer(local_path) [model, tokenizer] end + # Resolve a path-or-repo-id to a local directory, downloading from HF Hub if needed. + def resolve_path(path_or_repo, revision: "main") + return path_or_repo if File.directory?(path_or_repo) + Hub.snapshot_download( + path_or_repo, + revision: revision, + allow_patterns: ["*.json", "*.txt", "*.safetensors", "*.jinja", "*.model"], + ).to_s + end + # Load model from a local directory containing config.json and safetensors. # # @param model_path [String] Path to the model directory