Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,13 @@ and so forth, much faster than fine-tuning.
This is also useful for cybersecurity researchers who want to reduce a model's
willingness to provide dual-use or offensive security guidance.

For `ds4-server`, directional steering defaults to the tool-safe
`final-answer` policy: prompt prefill, thinking tokens, and DSML tool-call
syntax stay unsteered, while final visible answer prose uses the configured
direction. Use `--dir-steering-policy decoding` to leave only prefill
unsteered, `always` for the original always-on behavior, or `off` to disable
server-side steering.

## Test Vectors

`tests/test-vectors` contains short and long-context continuation vectors
Expand Down
14 changes: 14 additions & 0 deletions dir-steering/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@ With no steering file or zero scales, ds4 follows the normal inference path.
--dir-steering-file FILE load a 43 x 4096 f32 direction file
--dir-steering-ffn F apply steering after FFN outputs; default is 1 when a file is provided
--dir-steering-attn F apply steering after attention outputs; default is 0
--dir-steering-policy MODE server-only policy: final-answer, decoding, always, or off; default is final-answer
```

The FFN output is usually the best first target because it is late enough in
each layer to represent behavior, style, and topic signals. Attention steering
is available for experiments, but it can be more fragile.

For tool-using agents, `ds4-server` defaults to `--dir-steering-policy
final-answer`. This keeps prompt prefill, thinking tokens, and DSML tool-call
tokens unsteered. Steering is re-enabled only after generation has clearly
entered final natural-language answer text. This avoids letting a
behavior/style vector perturb tool-call grammar while still allowing the final
prose to use the configured direction.

`--dir-steering-policy decoding` is a middle ground for experiments that should
leave prompt/prefill activations untouched but steer every generated token,
including thinking and tool-call syntax. `always` restores the original
always-on behavior, and `off` disables directional steering at the server policy
layer.

## Verbosity Example

The bundled example builds a style direction from 100 paired prompts. Each pair
Expand Down
88 changes: 82 additions & 6 deletions ds4.c
Original file line number Diff line number Diff line change
Expand Up @@ -15588,6 +15588,9 @@ struct ds4_session {
int ctx_size;
bool checkpoint_valid;
bool mtp_draft_valid;
bool directional_steering_override;
float directional_steering_attn_scale;
float directional_steering_ffn_scale;
};

/* =========================================================================
Expand Down Expand Up @@ -15788,6 +15791,69 @@ static bool ds4_session_is_cpu(const ds4_session *s) {
return s && s->engine && s->engine->backend == DS4_BACKEND_CPU;
}

static void ds4_session_directional_steering_scales(const ds4_session *s,
float *attn,
float *ffn) {
float a = 0.0f;
float f = 0.0f;
if (s && s->engine) {
if (s->directional_steering_override) {
a = s->directional_steering_attn_scale;
f = s->directional_steering_ffn_scale;
} else {
a = s->engine->directional_steering_attn_scale;
f = s->engine->directional_steering_ffn_scale;
}
}
if (attn) *attn = a;
if (ffn) *ffn = f;
}

static void ds4_session_apply_directional_steering_to_backend(ds4_session *s) {
if (!s) return;
#ifndef DS4_NO_GPU
if (!ds4_session_is_cpu(s)) {
float attn = 0.0f;
float ffn = 0.0f;
ds4_session_directional_steering_scales(s, &attn, &ffn);
s->graph.directional_steering_attn_scale = attn;
s->graph.directional_steering_ffn_scale = ffn;
}
#else
(void)s;
#endif
}

static void ds4_session_set_directional_steering_state(ds4_session *s,
bool override,
float attn,
float ffn) {
if (!s) return;
float old_attn = 0.0f;
float old_ffn = 0.0f;
ds4_session_directional_steering_scales(s, &old_attn, &old_ffn);

s->directional_steering_override = override;
s->directional_steering_attn_scale = attn;
s->directional_steering_ffn_scale = ffn;

float new_attn = 0.0f;
float new_ffn = 0.0f;
ds4_session_directional_steering_scales(s, &new_attn, &new_ffn);
if (old_attn != new_attn || old_ffn != new_ffn) {
s->mtp_draft_valid = false;
}
ds4_session_apply_directional_steering_to_backend(s);
}

void ds4_session_set_directional_steering(ds4_session *s, float attn, float ffn) {
ds4_session_set_directional_steering_state(s, true, attn, ffn);
}

void ds4_session_use_engine_directional_steering(ds4_session *s) {
ds4_session_set_directional_steering_state(s, false, 0.0f, 0.0f);
}

static uint32_t session_cpu_raw_live_rows(const ds4_session *s) {
if (!s || !s->checkpoint_valid) return 0;
uint32_t rows = ds4_default_raw_cap((uint32_t)s->ctx_size);
Expand Down Expand Up @@ -17276,6 +17342,9 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
}
if (ds4_session_is_cpu(s)) {
ds4_engine *e = s->engine;
float steering_attn = 0.0f;
float steering_ffn = 0.0f;
ds4_session_directional_steering_scales(s, &steering_attn, &steering_ffn);
if (s->checkpoint_valid &&
prompt->len >= s->checkpoint.len &&
ds4_tokens_starts_with(prompt, &s->checkpoint))
Expand All @@ -17289,8 +17358,8 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
prompt->v[i],
(uint32_t)s->checkpoint.len,
e->directional_steering_dirs,
e->directional_steering_attn_scale,
e->directional_steering_ffn_scale,
steering_attn,
steering_ffn,
&s->cpu_scratch);
token_vec_push(&s->checkpoint, prompt->v[i]);
if (s->progress) s->progress(s->progress_ud, "prefill_chunk", i + 1, prompt->len);
Expand All @@ -17306,8 +17375,8 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
&s->cpu_cache,
prompt,
e->directional_steering_dirs,
e->directional_steering_attn_scale,
e->directional_steering_ffn_scale);
steering_attn,
steering_ffn);
ds4_tokens_copy(&s->checkpoint, prompt);
s->checkpoint_valid = true;
s->mtp_draft_valid = false;
Expand Down Expand Up @@ -17560,15 +17629,18 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp,
if (!s) return 1;
if (ds4_session_is_cpu(s)) {
ds4_engine *e = s->engine;
float steering_attn = 0.0f;
float steering_ffn = 0.0f;
ds4_session_directional_steering_scales(s, &steering_attn, &steering_ffn);
forward_token_raw_swa_cpu_decode_scratch(s->logits,
&e->model,
&e->weights,
&s->cpu_cache,
token,
(uint32_t)s->checkpoint.len,
e->directional_steering_dirs,
e->directional_steering_attn_scale,
e->directional_steering_ffn_scale,
steering_attn,
steering_ffn,
&s->cpu_scratch);
token_vec_push(&s->checkpoint, token);
s->checkpoint_valid = true;
Expand Down Expand Up @@ -17636,6 +17708,10 @@ int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen) {
return ds4_session_eval_internal(s, token, true, err, errlen);
}

int ds4_session_eval_no_mtp(ds4_session *s, int token, char *err, size_t errlen) {
return ds4_session_eval_internal(s, token, false, err, errlen);
}

/* Speculative decode state machine:
* 1. commit the normal target token and use its logits to validate draft[0];
* 2. let MTP recursively draft a tiny suffix from its own raw-cache frontier;
Expand Down
3 changes: 3 additions & 0 deletions ds4.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ int ds4_token_eos(ds4_engine *e);
int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size);
void ds4_session_free(ds4_session *s);
void ds4_session_set_progress(ds4_session *s, ds4_session_progress_fn fn, void *ud);
void ds4_session_set_directional_steering(ds4_session *s, float attn, float ffn);
void ds4_session_use_engine_directional_steering(ds4_session *s);

typedef enum {
DS4_SESSION_REWRITE_ERROR = -1,
Expand All @@ -169,6 +171,7 @@ int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p
int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k);
int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out);
int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen);
int ds4_session_eval_no_mtp(ds4_session *s, int token, char *err, size_t errlen);
int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token,
int max_tokens, int eos_token,
int *accepted, int accepted_cap,
Expand Down
Loading