diff --git a/lib/mlx_lm.rb b/lib/mlx_lm.rb index fc06890..10bfc97 100644 --- a/lib/mlx_lm.rb +++ b/lib/mlx_lm.rb @@ -127,6 +127,7 @@ require_relative "mlx_lm/quantize" require_relative "mlx_lm/quant/awq" require_relative "mlx_lm/quant/gptq" +require_relative "mlx_lm/hub" require_relative "mlx_lm/load_utils" require_relative "mlx_lm/evaluate" require_relative "mlx_lm/tuner/callbacks" diff --git a/lib/mlx_lm/hub.rb b/lib/mlx_lm/hub.rb new file mode 100644 index 0000000..470c888 --- /dev/null +++ b/lib/mlx_lm/hub.rb @@ -0,0 +1,131 @@ +require "fileutils" +require "json" +require "net/http" +require "pathname" +require "uri" + +module MlxLm + # Minimal pure-Ruby equivalent of huggingface_hub.snapshot_download. + # Cache layout matches the Python client so caches are mutually reusable. + module Hub + DEFAULT_ENDPOINT = "https://huggingface.co".freeze + MAX_REDIRECTS = 5 + DEFAULT_TIMEOUT = 300 + + module_function + + # Download a model snapshot from Hugging Face Hub. + # + # @param repo_id [String] e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit" + # @param revision [String] branch, tag, or commit SHA (default "main") + # @param allow_patterns [Array, String, nil] glob patterns to include + # @param token [String, nil] defaults to ENV["HF_TOKEN"] + # @param cache_dir [String, Pathname, nil] defaults to HF_HUB_CACHE / HF_HOME/hub / ~/.cache/huggingface/hub + # @param endpoint [String, nil] defaults to HF_ENDPOINT or huggingface.co + # @return [Pathname] absolute path to the snapshot directory + def snapshot_download(repo_id, revision: "main", allow_patterns: nil, token: nil, cache_dir: nil, endpoint: nil) + endpoint ||= ENV["HF_ENDPOINT"] || DEFAULT_ENDPOINT + token ||= ENV["HF_TOKEN"] + cache_dir = resolve_cache_dir(cache_dir) + patterns = normalize_patterns(allow_patterns) + + info = fetch_model_info(endpoint, repo_id, revision, token) + sha = info.fetch("sha") + siblings = info.fetch("siblings").map { |s| s.fetch("rfilename") } + + repo_folder = cache_dir.join("models--#{repo_id.gsub("/", "--")}") + snapshot_dir = repo_folder.join("snapshots", sha) + FileUtils.mkdir_p(snapshot_dir) + FileUtils.mkdir_p(repo_folder.join("refs")) + File.write(repo_folder.join("refs", revision), sha) + + siblings.each do |rel| + next unless pattern_match?(rel, patterns) + target = snapshot_dir.join(rel) + next if target.file? && target.size > 0 + + FileUtils.mkdir_p(target.dirname) + url = "#{endpoint}/#{repo_id}/resolve/#{revision}/#{rel}" + download_file(url, target.to_s, token) + end + + snapshot_dir + end + + def resolve_cache_dir(explicit) + return Pathname.new(explicit) if explicit + if (v = ENV["HF_HUB_CACHE"]) && !v.empty? + Pathname.new(v) + elsif (v = ENV["HF_HOME"]) && !v.empty? + Pathname.new(v).join("hub") + else + Pathname.new(Dir.home).join(".cache", "huggingface", "hub") + end + end + + def normalize_patterns(p) + return nil if p.nil? + p.is_a?(::Array) ? p : [p] + end + + def pattern_match?(filename, patterns) + return true if patterns.nil? || patterns.empty? + patterns.any? do |pat| + File.fnmatch(pat, filename, File::FNM_PATHNAME) || + File.fnmatch(pat, File.basename(filename)) + end + end + + def fetch_model_info(endpoint, repo_id, revision, token) + url = "#{endpoint}/api/models/#{repo_id}/revision/#{revision}" + body = http_get_body(url, token) + JSON.parse(body) + end + + def http_get_body(url, token, limit = MAX_REDIRECTS) + raise "Too many redirects fetching #{url}" if limit <= 0 + uri = URI.parse(url) + Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http| + req = Net::HTTP::Get.new(uri.request_uri) + req["Authorization"] = "Bearer #{token}" if token && uri.host.end_with?("huggingface.co") + req["User-Agent"] = user_agent + resp = http.request(req) + case resp + when Net::HTTPSuccess then resp.body + when Net::HTTPRedirection then http_get_body(URI.join(url, resp["location"]).to_s, token, limit - 1) + else raise "HTTP #{resp.code} #{resp.message} fetching #{url}: #{resp.body}" + end + end + end + + def download_file(url, output_path, token = nil, limit = MAX_REDIRECTS) + raise "Too many redirects fetching #{url}" if limit <= 0 + uri = URI.parse(url) + Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https", read_timeout: DEFAULT_TIMEOUT) do |http| + req = Net::HTTP::Get.new(uri.request_uri) + req["Authorization"] = "Bearer #{token}" if token && uri.host.end_with?("huggingface.co") + req["User-Agent"] = user_agent + http.request(req) do |resp| + case resp + when Net::HTTPSuccess + tmp = "#{output_path}.download.#{Process.pid}" + File.open(tmp, "wb") { |f| resp.read_body { |chunk| f.write(chunk) } } + FileUtils.mv(tmp, output_path) + when Net::HTTPRedirection + return download_file(URI.join(url, resp["location"]).to_s, output_path, token, limit - 1) + else + raise "HTTP #{resp.code} #{resp.message} fetching #{url}" + end + end + end + end + + def user_agent + v = defined?(MlxLm::VERSION) ? MlxLm::VERSION : "unknown" + "mlx-ruby-lm/#{v}" + end + + private_class_method :resolve_cache_dir, :normalize_patterns, :pattern_match?, + :fetch_model_info, :download_file, :user_agent + end +end diff --git a/lib/mlx_lm/load_utils.rb b/lib/mlx_lm/load_utils.rb index 02bbfa2..e5e8d73 100644 --- a/lib/mlx_lm/load_utils.rb +++ b/lib/mlx_lm/load_utils.rb @@ -4,17 +4,29 @@ module MlxLm module LoadUtils module_function - # Load a model and tokenizer from a local directory. + # Load a model and tokenizer from a local directory or Hugging Face repo id. # - # @param model_path [String] Path to the model directory + # @param model_path [String] Local directory path, or HF repo id like "org/model" # @param tokenizer_config [Hash] Additional tokenizer config overrides + # @param revision [String] HF branch/tag/sha; ignored for local paths # @return [Array(nn::Module, TokenizerWrapper)] The loaded model and tokenizer - def load(model_path, tokenizer_config: nil) - model, _config = load_model(model_path) - tokenizer = load_tokenizer(model_path) + def load(model_path, tokenizer_config: nil, revision: "main") + local_path = resolve_path(model_path, revision: revision) + model, _config = load_model(local_path) + tokenizer = load_tokenizer(local_path) [model, tokenizer] end + # Resolve a path-or-repo-id to a local directory, downloading from HF Hub if needed. + def resolve_path(path_or_repo, revision: "main") + return path_or_repo if File.directory?(path_or_repo) + Hub.snapshot_download( + path_or_repo, + revision: revision, + allow_patterns: ["*.json", "*.txt", "*.safetensors", "*.jinja", "*.model"], + ).to_s + end + # Load model from a local directory containing config.json and safetensors. # # @param model_path [String] Path to the model directory diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index cffa5c4..abdcf36 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -51,30 +51,21 @@ def _tensor_to_mlx(info, mx) dtype_str = info["dtype"] data = info["data"] - # For F32/float32, unpack as little-endian floats - if dtype_str == "F32" || dtype_str == "float32" - values = data.unpack("e*") - mx.array(values).reshape(shape) - elsif dtype_str == "F16" - # 16-bit float: unpack as uint16, create array as float32, then view as float16 - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape) - elsif dtype_str == "BF16" - values = data.unpack("S<*") - mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape) - elsif dtype_str == "I32" || dtype_str == "int32" - values = data.unpack("l<*") - mx.array(values, dtype: mx.int32).reshape(shape) - elsif dtype_str == "I64" - values = data.unpack("q<*") - mx.array(values, dtype: mx.int64).reshape(shape) - elsif dtype_str == "U8" - values = data.unpack("C*") - mx.array(values, dtype: mx.uint8).reshape(shape) + dtype_str = "F32" if dtype_str == "float32" + dtype_str = "I32" if dtype_str == "int32" + + # F16/BF16 lack a direct unpack path; stage through uint16 + .view. + if dtype_str == "F16" || dtype_str == "BF16" + view_dtype = dtype_str == "F16" ? mx.float16 : mx.bfloat16 + return mx.array(data.unpack("S<*"), dtype: mx.uint16).view(view_dtype).reshape(shape) + end + + format_str, dtype_sym = DTYPE_UNPACK[dtype_str] + if format_str + mx.array(data.unpack(format_str), dtype: mx.__send__(dtype_sym)).reshape(shape) else - # Fallback: try F32 - values = data.unpack("e*") - mx.array(values).reshape(shape) + # Unknown dtype — interpret raw bytes as little-endian F32. + mx.array(data.unpack("e*")).reshape(shape) end end