Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/mlx_lm.rb
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
require_relative "mlx_lm/quantize"
require_relative "mlx_lm/quant/awq"
require_relative "mlx_lm/quant/gptq"
require_relative "mlx_lm/hub"
require_relative "mlx_lm/load_utils"
require_relative "mlx_lm/evaluate"
require_relative "mlx_lm/tuner/callbacks"
Expand Down
131 changes: 131 additions & 0 deletions lib/mlx_lm/hub.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
require "fileutils"
require "json"
require "net/http"
require "pathname"
require "uri"

module MlxLm
# Minimal pure-Ruby equivalent of huggingface_hub.snapshot_download.
# Cache layout matches the Python client so caches are mutually reusable.
module Hub
DEFAULT_ENDPOINT = "https://huggingface.co".freeze
MAX_REDIRECTS = 5
DEFAULT_TIMEOUT = 300

module_function

# Download a model snapshot from Hugging Face Hub.
#
# @param repo_id [String] e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit"
# @param revision [String] branch, tag, or commit SHA (default "main")
# @param allow_patterns [Array<String>, String, nil] glob patterns to include
# @param token [String, nil] defaults to ENV["HF_TOKEN"]
# @param cache_dir [String, Pathname, nil] defaults to HF_HUB_CACHE / HF_HOME/hub / ~/.cache/huggingface/hub
# @param endpoint [String, nil] defaults to HF_ENDPOINT or huggingface.co
# @return [Pathname] absolute path to the snapshot directory
def snapshot_download(repo_id, revision: "main", allow_patterns: nil, token: nil, cache_dir: nil, endpoint: nil)
endpoint ||= ENV["HF_ENDPOINT"] || DEFAULT_ENDPOINT
token ||= ENV["HF_TOKEN"]
cache_dir = resolve_cache_dir(cache_dir)
patterns = normalize_patterns(allow_patterns)

info = fetch_model_info(endpoint, repo_id, revision, token)
sha = info.fetch("sha")
siblings = info.fetch("siblings").map { |s| s.fetch("rfilename") }

repo_folder = cache_dir.join("models--#{repo_id.gsub("/", "--")}")
snapshot_dir = repo_folder.join("snapshots", sha)
FileUtils.mkdir_p(snapshot_dir)
FileUtils.mkdir_p(repo_folder.join("refs"))
File.write(repo_folder.join("refs", revision), sha)

siblings.each do |rel|
next unless pattern_match?(rel, patterns)
target = snapshot_dir.join(rel)
next if target.file? && target.size > 0

FileUtils.mkdir_p(target.dirname)
url = "#{endpoint}/#{repo_id}/resolve/#{revision}/#{rel}"
download_file(url, target.to_s, token)
end

snapshot_dir
end

def resolve_cache_dir(explicit)
return Pathname.new(explicit) if explicit
if (v = ENV["HF_HUB_CACHE"]) && !v.empty?
Pathname.new(v)
elsif (v = ENV["HF_HOME"]) && !v.empty?
Pathname.new(v).join("hub")
else
Pathname.new(Dir.home).join(".cache", "huggingface", "hub")
end
end

def normalize_patterns(p)
return nil if p.nil?
p.is_a?(::Array) ? p : [p]
end

def pattern_match?(filename, patterns)
return true if patterns.nil? || patterns.empty?
patterns.any? do |pat|
File.fnmatch(pat, filename, File::FNM_PATHNAME) ||
File.fnmatch(pat, File.basename(filename))
end
end

def fetch_model_info(endpoint, repo_id, revision, token)
url = "#{endpoint}/api/models/#{repo_id}/revision/#{revision}"
body = http_get_body(url, token)
JSON.parse(body)
end

def http_get_body(url, token, limit = MAX_REDIRECTS)
raise "Too many redirects fetching #{url}" if limit <= 0
uri = URI.parse(url)
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
req = Net::HTTP::Get.new(uri.request_uri)
req["Authorization"] = "Bearer #{token}" if token && uri.host.end_with?("huggingface.co")
req["User-Agent"] = user_agent
resp = http.request(req)
case resp
when Net::HTTPSuccess then resp.body
when Net::HTTPRedirection then http_get_body(URI.join(url, resp["location"]).to_s, token, limit - 1)
else raise "HTTP #{resp.code} #{resp.message} fetching #{url}: #{resp.body}"
end
end
end

def download_file(url, output_path, token = nil, limit = MAX_REDIRECTS)
raise "Too many redirects fetching #{url}" if limit <= 0
uri = URI.parse(url)
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https", read_timeout: DEFAULT_TIMEOUT) do |http|
req = Net::HTTP::Get.new(uri.request_uri)
req["Authorization"] = "Bearer #{token}" if token && uri.host.end_with?("huggingface.co")
req["User-Agent"] = user_agent
http.request(req) do |resp|
case resp
when Net::HTTPSuccess
tmp = "#{output_path}.download.#{Process.pid}"
File.open(tmp, "wb") { |f| resp.read_body { |chunk| f.write(chunk) } }
FileUtils.mv(tmp, output_path)
when Net::HTTPRedirection
return download_file(URI.join(url, resp["location"]).to_s, output_path, token, limit - 1)
else
raise "HTTP #{resp.code} #{resp.message} fetching #{url}"
end
end
end
end

def user_agent
v = defined?(MlxLm::VERSION) ? MlxLm::VERSION : "unknown"
"mlx-ruby-lm/#{v}"
end

private_class_method :resolve_cache_dir, :normalize_patterns, :pattern_match?,
:fetch_model_info, :download_file, :user_agent
end
end
22 changes: 17 additions & 5 deletions lib/mlx_lm/load_utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,29 @@ module MlxLm
module LoadUtils
module_function

# Load a model and tokenizer from a local directory.
# Load a model and tokenizer from a local directory or Hugging Face repo id.
#
# @param model_path [String] Path to the model directory
# @param model_path [String] Local directory path, or HF repo id like "org/model"
# @param tokenizer_config [Hash] Additional tokenizer config overrides
# @param revision [String] HF branch/tag/sha; ignored for local paths
# @return [Array(nn::Module, TokenizerWrapper)] The loaded model and tokenizer
def load(model_path, tokenizer_config: nil)
model, _config = load_model(model_path)
tokenizer = load_tokenizer(model_path)
def load(model_path, tokenizer_config: nil, revision: "main")
local_path = resolve_path(model_path, revision: revision)
model, _config = load_model(local_path)
tokenizer = load_tokenizer(local_path)
[model, tokenizer]
end

# Resolve a path-or-repo-id to a local directory, downloading from HF Hub if needed.
def resolve_path(path_or_repo, revision: "main")
return path_or_repo if File.directory?(path_or_repo)
Hub.snapshot_download(
path_or_repo,
revision: revision,
allow_patterns: ["*.json", "*.txt", "*.safetensors", "*.jinja", "*.model"],
).to_s
end

# Load model from a local directory containing config.json and safetensors.
#
# @param model_path [String] Path to the model directory
Expand Down
37 changes: 14 additions & 23 deletions lib/mlx_lm/weight_utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,21 @@ def _tensor_to_mlx(info, mx)
dtype_str = info["dtype"]
data = info["data"]

# For F32/float32, unpack as little-endian floats
if dtype_str == "F32" || dtype_str == "float32"
values = data.unpack("e*")
mx.array(values).reshape(shape)
elsif dtype_str == "F16"
# 16-bit float: unpack as uint16, create array as float32, then view as float16
values = data.unpack("S<*")
mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape)
elsif dtype_str == "BF16"
values = data.unpack("S<*")
mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape)
elsif dtype_str == "I32" || dtype_str == "int32"
values = data.unpack("l<*")
mx.array(values, dtype: mx.int32).reshape(shape)
elsif dtype_str == "I64"
values = data.unpack("q<*")
mx.array(values, dtype: mx.int64).reshape(shape)
elsif dtype_str == "U8"
values = data.unpack("C*")
mx.array(values, dtype: mx.uint8).reshape(shape)
dtype_str = "F32" if dtype_str == "float32"
dtype_str = "I32" if dtype_str == "int32"

# F16/BF16 lack a direct unpack path; stage through uint16 + .view.
if dtype_str == "F16" || dtype_str == "BF16"
view_dtype = dtype_str == "F16" ? mx.float16 : mx.bfloat16
return mx.array(data.unpack("S<*"), dtype: mx.uint16).view(view_dtype).reshape(shape)
end

format_str, dtype_sym = DTYPE_UNPACK[dtype_str]
if format_str
mx.array(data.unpack(format_str), dtype: mx.__send__(dtype_sym)).reshape(shape)
else
# Fallback: try F32
values = data.unpack("e*")
mx.array(values).reshape(shape)
# Unknown dtype — interpret raw bytes as little-endian F32.
mx.array(data.unpack("e*")).reshape(shape)
end
end

Expand Down