Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 113 additions & 6 deletions dflash/src/safetensors_draft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,10 @@ ggml_tensor * alloc_tensor(ggml_context * ctx,
return nullptr;
}
const StEntry & e = it->second;
if (e.dtype != dtype_expected) {
(void)dtype_expected;
if (e.dtype != "BF16" && e.dtype != "F16") {
set_last_error("safetensors: '" + name + "' dtype=" + e.dtype +
" expected " + dtype_expected);
" expected BF16 or F16");
return nullptr;
}
if (e.shape.size() != expected_shape.size()) {
Expand All @@ -277,10 +278,10 @@ ggml_tensor * alloc_tensor(ggml_context * ctx,
}
}
ggml_type gt = (gt_override == GGML_TYPE_COUNT)
? st_dtype_to_ggml(dtype_expected)
? st_dtype_to_ggml(e.dtype)
: gt_override;
if (gt == GGML_TYPE_COUNT) {
set_last_error("safetensors: unsupported dtype " + dtype_expected);
set_last_error("safetensors: unsupported dtype " + e.dtype);
return nullptr;
}

Expand All @@ -307,6 +308,40 @@ static void bf16_to_f32_array(const uint16_t * src, float * dst, size_t n) {
}
}

// Decode a single IEEE 754 binary16 bit pattern into a 32-bit float.
// Handles signed zero, subnormals, normals, and Inf/NaN (NaN payload bits are
// carried over into the top of the float mantissa).
static float f16_to_f32(uint16_t h) {
    const uint32_t sign_bit = ((uint32_t)h & 0x8000u) << 16;
    const uint32_t exp_bits = ((uint32_t)h >> 10) & 0x1Fu;
    uint32_t frac = (uint32_t)h & 0x03FFu;

    // Start from the sign alone; that already covers +/-0.
    uint32_t bits = sign_bit;

    if (exp_bits == 0x1Fu) {
        // Inf / NaN: all-ones exponent, mantissa widened from 10 to 23 bits.
        bits |= 0x7F800000u | (frac << 13);
    } else if (exp_bits != 0u) {
        // Normal number: rebias the exponent from 15 to 127.
        bits |= ((exp_bits + (127u - 15u)) << 23) | (frac << 13);
    } else if (frac != 0u) {
        // Subnormal: shift the mantissa up until the implicit leading one
        // appears at bit 10, lowering the effective exponent for each shift.
        int32_t unbiased = 1;
        for (; (frac & 0x0400u) == 0u; frac <<= 1) {
            --unbiased;
        }
        frac &= 0x03FFu;  // drop the now-implicit leading one
        bits |= ((uint32_t)(unbiased + (127 - 15)) << 23) | (frac << 13);
    }

    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}

// Widen `n` binary16 values read from `src` into floats written to `dst`,
// one element at a time via f16_to_f32().
static void f16_to_f32_array(const uint16_t * src, float * dst, size_t n) {
    const uint16_t * const src_end = src + n;
    while (src != src_end) {
        *dst++ = f16_to_f32(*src++);
    }
}

// Convert an array of bf16 values to fp16 via f32 intermediate.
static void bf16_to_f16_array(const uint16_t * src, uint16_t * dst, size_t n) {
for (size_t i = 0; i < n; i++) {
Expand Down Expand Up @@ -433,14 +468,76 @@ bool load_draft_safetensors(const std::string & path,
}
}

// ── 4b. Read config.json for SWA layer_types (Qwen3.6 draft) ──
// Best-effort step: when config.json is missing, unreadable, or lacks the
// keys below, the values already stored in `out` are left untouched.
{
    // config.json sits next to model.safetensors
    std::string dir;
    auto slash = path.find_last_of("/\\");
    if (slash != std::string::npos) {
        dir = path.substr(0, slash);
    } else {
        dir = ".";  // bare filename — look in CWD
    }
    std::string cfg_path = dir + "/config.json";

    // Binary mode so the size reported by ftell matches what fread delivers
    // (text mode may translate line endings on some platforms).
    FILE * f = std::fopen(cfg_path.c_str(), "rb");
    if (f) {
        // Slurp the file. Any seek/tell/read failure leaves `cfg` empty (or
        // truncated to the bytes actually read), so the lookups below simply
        // find nothing instead of sizing a string from a negative length or
        // scanning uninitialised padding.
        std::string cfg;
        if (std::fseek(f, 0, SEEK_END) == 0) {
            const long flen = std::ftell(f);
            if (flen > 0 && std::fseek(f, 0, SEEK_SET) == 0) {
                cfg.resize(static_cast<size_t>(flen));
                const size_t got = std::fread(&cfg[0], 1, cfg.size(), f);
                cfg.resize(got);
            }
        }
        std::fclose(f);

        // Parse sliding_window: take the integer following the first ':'
        // after the key; non-positive / non-numeric values (e.g. null) are
        // ignored.
        auto sw_pos = cfg.find("\"sliding_window\"");
        if (sw_pos != std::string::npos) {
            auto colon = cfg.find(':', sw_pos);
            if (colon != std::string::npos) {
                int sw = std::atoi(cfg.c_str() + colon + 1);
                if (sw > 0) out.swa_window = sw;
            }
        }

        // Parse layer_types array: the i-th quoted string inside the first
        // [...] after the key decides whether layer i uses sliding attention.
        auto lt_pos = cfg.find("\"layer_types\"");
        if (lt_pos != std::string::npos) {
            auto arr_start = cfg.find('[', lt_pos);
            auto arr_end = (arr_start != std::string::npos)
                               ? cfg.find(']', arr_start)
                               : std::string::npos;
            if (arr_start != std::string::npos && arr_end != std::string::npos) {
                std::string arr = cfg.substr(arr_start, arr_end - arr_start + 1);
                int li = 0;
                size_t search_pos = 0;
                // Stop at n_layers even if the config lists more entries.
                while (li < n_layers && search_pos < arr.size()) {
                    auto q1 = arr.find('"', search_pos);
                    if (q1 == std::string::npos) break;
                    auto q2 = arr.find('"', q1 + 1);
                    if (q2 == std::string::npos) break;
                    std::string lt = arr.substr(q1 + 1, q2 - q1 - 1);
                    out.layers[li].is_swa = (lt == "sliding_attention");
                    li++;
                    search_pos = q2 + 1;
                }
            }
        }

        // Report how many layers ended up flagged as sliding-window.
        int n_swa = 0;
        for (int il = 0; il < n_layers; il++) {
            if (out.layers[il].is_swa) n_swa++;
        }
        if (n_swa > 0) {
            fprintf(stderr, "[draft] SWA layers: %d/%d (window=%d)\n", n_swa, n_layers, out.swa_window);
        }
    }
}

// ── 5. Allocate backend buffer, copy bytes ───────────────────
out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend);
if (!out.buf) { set_last_error("ggml_backend_alloc_ctx_tensors failed (draft)"); return false; }

// Walk the tensors in the context and upload their bytes.
// For tensors whose ggml type differs from the safetensors dtype (i.e.
// BF16-on-disk, F32-in-ggml for norms, or BF16-on-disk, F16-in-ggml for
// projection weights on Turing), convert on the fly via scratch buffers.
// BF16/F16-on-disk, F32-in-ggml for norms, or BF16-on-disk, F16-in-ggml
// for projection weights on Turing), convert on the fly via scratch buffers.
std::vector<float> scratch_f32;
std::vector<uint16_t> scratch_f16;
for (ggml_tensor * t = ggml_get_first_tensor(out.ctx); t != nullptr;
Expand Down Expand Up @@ -480,6 +577,16 @@ bool load_draft_safetensors(const std::string & path,
bf16_to_f32_array((const uint16_t *)(blob + e.data_start),
scratch_f32.data(), n);
ggml_backend_tensor_set(t, scratch_f32.data(), 0, dst_nbytes);
} else if (e.dtype == "F16" && t->type == GGML_TYPE_F32) {
const size_t n = ggml_nelements(t);
if (src_nbytes != n * sizeof(uint16_t) || dst_nbytes != n * sizeof(float)) {
set_last_error("F16->F32 size mismatch for '" + std::string(name) + "'");
return false;
}
scratch_f32.resize(n);
f16_to_f32_array((const uint16_t *)(blob + e.data_start),
scratch_f32.data(), n);
ggml_backend_tensor_set(t, scratch_f32.data(), 0, dst_nbytes);
} else if (e.dtype == "BF16" && t->type == GGML_TYPE_F16) {
const size_t n = ggml_nelements(t);
if (src_nbytes != n * sizeof(uint16_t) || dst_nbytes != n * sizeof(uint16_t)) {
Expand Down