diff --git a/README.md b/README.md index ea02bf39..1e09fb79 100644 --- a/README.md +++ b/README.md @@ -719,6 +719,13 @@ and so forth, much faster than fine-tuning. This is also useful for cybersecurity researchers who want to reduce a model's willingness to provide dual-use or offensive security guidance. +For `ds4-server`, directional steering defaults to the tool-safe +`final-answer` policy: prompt prefill, thinking tokens, and DSML tool-call +syntax stay unsteered, while final visible answer prose uses the configured +direction. Use `--dir-steering-policy decoding` to leave only prefill +unsteered, `always` for the original always-on behavior, or `off` to disable +server-side steering. + ## Test Vectors `tests/test-vectors` contains short and long-context continuation vectors diff --git a/dir-steering/README.md b/dir-steering/README.md index e1fdbfe5..3a671297 100644 --- a/dir-steering/README.md +++ b/dir-steering/README.md @@ -17,12 +17,26 @@ With no steering file or zero scales, ds4 follows the normal inference path. --dir-steering-file FILE load a 43 x 4096 f32 direction file --dir-steering-ffn F apply steering after FFN outputs; default is 1 when a file is provided --dir-steering-attn F apply steering after attention outputs; default is 0 +--dir-steering-policy MODE server-only policy: final-answer, decoding, always, or off; default is final-answer ``` The FFN output is usually the best first target because it is late enough in each layer to represent behavior, style, and topic signals. Attention steering is available for experiments, but it can be more fragile. +For tool-using agents, `ds4-server` defaults to `--dir-steering-policy +final-answer`. This keeps prompt prefill, thinking tokens, and DSML tool-call +tokens unsteered. Steering is re-enabled only after generation has clearly +entered final natural-language answer text. This avoids letting a +behavior/style vector perturb tool-call grammar while still allowing the final +prose to use the configured direction. + +`--dir-steering-policy decoding` is a middle ground for experiments that should +leave prompt/prefill activations untouched but steer every generated token, +including thinking and tool-call syntax. `always` restores the original +always-on behavior, and `off` disables directional steering at the server policy +layer. + ## Verbosity Example The bundled example builds a style direction from 100 paired prompts. Each pair diff --git a/ds4.c b/ds4.c index 8825c257..c745990b 100644 --- a/ds4.c +++ b/ds4.c @@ -15588,6 +15588,9 @@ struct ds4_session { int ctx_size; bool checkpoint_valid; bool mtp_draft_valid; + bool directional_steering_override; + float directional_steering_attn_scale; + float directional_steering_ffn_scale; }; /* ========================================================================= @@ -15788,6 +15791,69 @@ static bool ds4_session_is_cpu(const ds4_session *s) { return s && s->engine && s->engine->backend == DS4_BACKEND_CPU; } +static void ds4_session_directional_steering_scales(const ds4_session *s, + float *attn, + float *ffn) { + float a = 0.0f; + float f = 0.0f; + if (s && s->engine) { + if (s->directional_steering_override) { + a = s->directional_steering_attn_scale; + f = s->directional_steering_ffn_scale; + } else { + a = s->engine->directional_steering_attn_scale; + f = s->engine->directional_steering_ffn_scale; + } + } + if (attn) *attn = a; + if (ffn) *ffn = f; +} + +static void ds4_session_apply_directional_steering_to_backend(ds4_session *s) { + if (!s) return; +#ifndef DS4_NO_GPU + if (!ds4_session_is_cpu(s)) { + float attn = 0.0f; + float ffn = 0.0f; + ds4_session_directional_steering_scales(s, &attn, &ffn); + s->graph.directional_steering_attn_scale = attn; + s->graph.directional_steering_ffn_scale = ffn; + } +#else + (void)s; +#endif +} + +static void ds4_session_set_directional_steering_state(ds4_session *s, + bool override, + float attn, + float ffn) { + if (!s) return; + float old_attn = 0.0f; + float old_ffn = 0.0f; + ds4_session_directional_steering_scales(s, &old_attn, &old_ffn); + + s->directional_steering_override = override; + s->directional_steering_attn_scale = attn; + s->directional_steering_ffn_scale = ffn; + + float new_attn = 0.0f; + float new_ffn = 0.0f; + ds4_session_directional_steering_scales(s, &new_attn, &new_ffn); + if (old_attn != new_attn || old_ffn != new_ffn) { + s->mtp_draft_valid = false; + } + ds4_session_apply_directional_steering_to_backend(s); +} + +void ds4_session_set_directional_steering(ds4_session *s, float attn, float ffn) { + ds4_session_set_directional_steering_state(s, true, attn, ffn); +} + +void ds4_session_use_engine_directional_steering(ds4_session *s) { + ds4_session_set_directional_steering_state(s, false, 0.0f, 0.0f); +} + static uint32_t session_cpu_raw_live_rows(const ds4_session *s) { if (!s || !s->checkpoint_valid) return 0; uint32_t rows = ds4_default_raw_cap((uint32_t)s->ctx_size); @@ -17276,6 +17342,9 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t } if (ds4_session_is_cpu(s)) { ds4_engine *e = s->engine; + float steering_attn = 0.0f; + float steering_ffn = 0.0f; + ds4_session_directional_steering_scales(s, &steering_attn, &steering_ffn); if (s->checkpoint_valid && prompt->len >= s->checkpoint.len && ds4_tokens_starts_with(prompt, &s->checkpoint)) @@ -17289,8 +17358,8 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t prompt->v[i], (uint32_t)s->checkpoint.len, e->directional_steering_dirs, - e->directional_steering_attn_scale, - e->directional_steering_ffn_scale, + steering_attn, + steering_ffn, &s->cpu_scratch); token_vec_push(&s->checkpoint, prompt->v[i]); if (s->progress) s->progress(s->progress_ud, "prefill_chunk", i + 1, prompt->len); @@ -17306,8 +17375,8 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t &s->cpu_cache, prompt, e->directional_steering_dirs, - e->directional_steering_attn_scale, - e->directional_steering_ffn_scale); + steering_attn, + steering_ffn); ds4_tokens_copy(&s->checkpoint, prompt); s->checkpoint_valid = true; s->mtp_draft_valid = false; @@ -17560,6 +17629,9 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, if (!s) return 1; if (ds4_session_is_cpu(s)) { ds4_engine *e = s->engine; + float steering_attn = 0.0f; + float steering_ffn = 0.0f; + ds4_session_directional_steering_scales(s, &steering_attn, &steering_ffn); forward_token_raw_swa_cpu_decode_scratch(s->logits, &e->model, &e->weights, @@ -17567,8 +17639,8 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp, token, (uint32_t)s->checkpoint.len, e->directional_steering_dirs, - e->directional_steering_attn_scale, - e->directional_steering_ffn_scale, + steering_attn, + steering_ffn, &s->cpu_scratch); token_vec_push(&s->checkpoint, token); s->checkpoint_valid = true; @@ -17636,6 +17708,10 @@ int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen) { return ds4_session_eval_internal(s, token, true, err, errlen); } +int ds4_session_eval_no_mtp(ds4_session *s, int token, char *err, size_t errlen) { + return ds4_session_eval_internal(s, token, false, err, errlen); +} + /* Speculative decode state machine: * 1. commit the normal target token and use its logits to validate draft[0]; * 2. let MTP recursively draft a tiny suffix from its own raw-cache frontier; diff --git a/ds4.h b/ds4.h index 950d8dca..bf40ec4c 100644 --- a/ds4.h +++ b/ds4.h @@ -145,6 +145,8 @@ int ds4_token_eos(ds4_engine *e); int ds4_session_create(ds4_session **out, ds4_engine *e, int ctx_size); void ds4_session_free(ds4_session *s); void ds4_session_set_progress(ds4_session *s, ds4_session_progress_fn fn, void *ud); +void ds4_session_set_directional_steering(ds4_session *s, float attn, float ffn); +void ds4_session_use_engine_directional_steering(ds4_session *s); typedef enum { DS4_SESSION_REWRITE_ERROR = -1, @@ -169,6 +171,7 @@ int ds4_session_sample(ds4_session *s, float temperature, int top_k, float top_p int ds4_session_top_logprobs(ds4_session *s, ds4_token_score *out, int k); int ds4_session_token_logprob(ds4_session *s, int token, ds4_token_score *out); int ds4_session_eval(ds4_session *s, int token, char *err, size_t errlen); +int ds4_session_eval_no_mtp(ds4_session *s, int token, char *err, size_t errlen); int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token, int max_tokens, int eos_token, int *accepted, int accepted_cap, diff --git a/ds4_server.c b/ds4_server.c index 9ebaa8b7..ef99bd42 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -478,6 +478,13 @@ typedef enum { API_RESPONSES, } api_style; +typedef enum { + DS4_STEERING_POLICY_ALWAYS, + DS4_STEERING_POLICY_DECODING, + DS4_STEERING_POLICY_FINAL_ANSWER, + DS4_STEERING_POLICY_OFF, +} directional_steering_policy; + static void random_tool_id(char *dst, size_t dstlen, api_style api) { static uint64_t fallback_ctr; unsigned char bytes[16]; @@ -5035,6 +5042,19 @@ static size_t dsml_max_tool_start_len(void) { return max; } +static bool dsml_text_ends_with_partial_tool_start(const char *raw, size_t raw_len) { + if (!raw || raw_len == 0) return false; + for (size_t i = 0; i < sizeof(dsml_syntaxes) / sizeof(dsml_syntaxes[0]); i++) { + const char *lit = dsml_syntaxes[i].tool_calls_start; + const size_t lit_len = strlen(lit); + const size_t max = raw_len < lit_len ? raw_len : lit_len - 1; + for (size_t n = 2; n <= max; n++) { + if (!memcmp(raw + raw_len - n, lit, n)) return true; + } + } + return false; +} + static bool dsml_find_tool_start(const char *raw, size_t raw_len, size_t *pos_out, const dsml_syntax **syn_out) { @@ -7469,6 +7489,7 @@ static void id_list_push_unique(stop_list *ids, const char *id); struct server { ds4_engine *engine; ds4_session *session; + directional_steering_policy steering_policy; int default_tokens; kv_disk_cache kv; tool_memory tool_mem; @@ -9778,6 +9799,90 @@ static thinking_state thinking_state_from_prompt(const request *r) { return st; } +static const char *directional_steering_policy_name(directional_steering_policy policy) { + switch (policy) { + case DS4_STEERING_POLICY_ALWAYS: return "always"; + case DS4_STEERING_POLICY_DECODING: return "decoding"; + case DS4_STEERING_POLICY_FINAL_ANSWER: return "final-answer"; + case DS4_STEERING_POLICY_OFF: return "off"; + } + return "unknown"; +} + +static bool request_has_tool_result_context(const request *r) { + return r && r->prompt_text && strstr(r->prompt_text, "") != NULL; +} + +static bool directional_steering_final_answer_context(const request *r, + bool responses_live_continuation, + bool anthropic_live_continuation) { + if (!r) return false; + if (r->kind != REQ_CHAT) return true; + if (!r->has_tools) return true; + return responses_live_continuation || + anthropic_live_continuation || + request_has_tool_result_context(r); +} + +static bool text_has_nonspace(const char *p, size_t len) { + if (!p) return false; + for (size_t i = 0; i < len; i++) { + if (!isspace((unsigned char)p[i])) return true; + } + return false; +} + +static bool directional_steering_should_apply( + directional_steering_policy policy, + bool final_answer_context, + bool saw_final_answer_text, + bool thinking_before, + bool thinking_after, + dsml_decode_state dsml_before, + dsml_decode_state dsml_after, + bool partial_tool_start, + const char *piece, + size_t piece_len, + bool *starts_final_answer_out) { + if (starts_final_answer_out) *starts_final_answer_out = false; + if (policy == DS4_STEERING_POLICY_ALWAYS) return true; + if (policy == DS4_STEERING_POLICY_DECODING) return true; + if (policy == DS4_STEERING_POLICY_OFF) return false; + + if (!final_answer_context) return false; + if (thinking_before || thinking_after) return false; + if (dsml_decode_state_is_tool(dsml_before) || + dsml_decode_state_is_tool(dsml_after) || + partial_tool_start) + { + return false; + } + + const bool starts = text_has_nonspace(piece, piece_len); + if (starts_final_answer_out) *starts_final_answer_out = starts; + return saw_final_answer_text || starts; +} + +static void server_apply_directional_steering(server *s, bool enable) { + if (!s || !s->session) return; + if (enable) { + ds4_session_use_engine_directional_steering(s->session); + } else { + ds4_session_set_directional_steering(s->session, 0.0f, 0.0f); + } +} + +static void server_apply_prefill_directional_steering(server *s) { + server_apply_directional_steering( + s, s && s->steering_policy == DS4_STEERING_POLICY_ALWAYS); +} + +static void server_apply_decode_directional_steering(server *s) { + server_apply_directional_steering( + s, s && (s->steering_policy == DS4_STEERING_POLICY_ALWAYS || + s->steering_policy == DS4_STEERING_POLICY_DECODING)); +} + static bool should_remember_thinking_checkpoint(const request *r, const thinking_state *thinking, const char *finish) { @@ -10314,6 +10419,7 @@ static void generate_job(server *s, job *j) { req_flags[0] ? " " : "", req_flags); ds4_session_set_progress(s->session, server_progress_cb, &progress); + server_apply_prefill_directional_steering(s); int cold_store_len = 0; if (cached == 0 && @@ -10448,6 +10554,14 @@ static void generate_job(server *s, job *j) { thinking_state thinking = thinking_state_from_prompt(&j->req); dsml_decode_tracker dsml_tracker; dsml_decode_tracker_init(&dsml_tracker); + const bool dynamic_steering = + s->steering_policy == DS4_STEERING_POLICY_FINAL_ANSWER; + const bool final_answer_context = + directional_steering_final_answer_context(&j->req, + responses_live_continuation, + anthropic_live_continuation); + bool saw_final_answer_text = false; + server_apply_decode_directional_steering(s); while (!g_stop_requested && completion < max_tokens && ds4_session_pos(s->session) < ds4_session_ctx(s->session)) { @@ -10478,9 +10592,11 @@ static void generate_job(server *s, job *j) { int toks[17]; int ntok = 0; + bool toks_evaluated = false; if (temperature <= 0.0f && ds4_engine_mtp_draft_tokens(s->engine) > 1 && - getenv("DS4_MTP_SPEC_DISABLE") == NULL) + getenv("DS4_MTP_SPEC_DISABLE") == NULL && + !dynamic_steering) { ntok = ds4_session_eval_speculative_argmax(s->session, token, @@ -10494,11 +10610,8 @@ static void generate_job(server *s, job *j) { finish = "error"; break; } + toks_evaluated = true; } else { - if (ds4_session_eval(s->session, token, err, sizeof(err)) != 0) { - finish = "error"; - break; - } toks[0] = token; ntok = 1; } @@ -10514,12 +10627,65 @@ static void generate_job(server *s, job *j) { size_t piece_len = 0; char *piece = ds4_token_text(s->engine, token, &piece_len); + thinking_state next_thinking = thinking; + dsml_decode_tracker next_dsml_tracker = dsml_tracker; + dsml_decode_state next_dsml_state = dsml_state; + bool starts_final_answer = false; + + if (!toks_evaluated) { + if (dynamic_steering) { + const bool thinking_before = thinking.inside; + thinking_state_feed(&next_thinking, piece, piece_len); + bool partial_tool_start = false; + if (j->req.kind == REQ_CHAT && j->req.has_tools) { + const size_t old_len = text.len; + buf_append(&text, piece, piece_len); + dsml_decode_tracker_update(&next_dsml_tracker, + text.ptr, text.len); + next_dsml_state = next_dsml_tracker.decode; + partial_tool_start = + dsml_text_ends_with_partial_tool_start(text.ptr, + text.len); + text.len = old_len; + if (text.ptr) text.ptr[text.len] = '\0'; + } + const bool steer_token = directional_steering_should_apply( + s->steering_policy, + final_answer_context, + saw_final_answer_text, + thinking_before, + next_thinking.inside, + dsml_state, + next_dsml_state, + partial_tool_start, + piece, + piece_len, + &starts_final_answer); + server_apply_directional_steering(s, steer_token); + } + int eval_rc = dynamic_steering ? + ds4_session_eval_no_mtp(s->session, token, err, sizeof(err)) : + ds4_session_eval(s->session, token, err, sizeof(err)); + if (eval_rc != 0) { + finish = "error"; + free(piece); + stop_decode = true; + break; + } + } completion++; trace_piece(s, trace_id, piece, piece_len); buf_append(&text, piece, piece_len); - thinking_state_feed(&thinking, piece, piece_len); - if (j->req.kind == REQ_CHAT && j->req.has_tools) { + if (dynamic_steering) { + thinking = next_thinking; + dsml_tracker = next_dsml_tracker; + if (starts_final_answer) saw_final_answer_text = true; + } else { + thinking_state_feed(&thinking, piece, piece_len); + } + if (!dynamic_steering && + j->req.kind == REQ_CHAT && j->req.has_tools) { dsml_decode_tracker_update(&dsml_tracker, text.ptr, text.len); } @@ -11243,6 +11409,7 @@ typedef struct { const char *kv_disk_dir; uint64_t kv_disk_space_mb; kv_cache_options kv_cache; + directional_steering_policy steering_policy; bool kv_cache_reject_different_quant; bool disable_exact_dsml_tool_replay; int tool_memory_max_ids; @@ -11345,6 +11512,8 @@ static void usage(FILE *fp) { " Apply steering after FFN outputs: y -= F*v*dot(v,y). Default with file: 1\n" " --dir-steering-attn F\n" " Apply steering after attention outputs. Default: 0\n" + " --dir-steering-policy MODE\n" + " Server steering policy: final-answer, decoding, always, or off. Default: final-answer\n" " --warm-weights\n" " Touch mapped tensor pages before serving. Slower startup, fewer first-use stalls.\n" " --metal | --cuda | --cpu | --backend NAME\n" @@ -11425,6 +11594,28 @@ static ds4_backend default_server_backend(void) { #endif } +static directional_steering_policy parse_directional_steering_policy_arg( + const char *s, + const char *arg) { + if (!strcmp(s, "always")) return DS4_STEERING_POLICY_ALWAYS; + if (!strcmp(s, "decoding") || !strcmp(s, "decode")) { + return DS4_STEERING_POLICY_DECODING; + } + if (!strcmp(s, "final-answer") || + !strcmp(s, "final") || + !strcmp(s, "tool-safe")) + { + return DS4_STEERING_POLICY_FINAL_ANSWER; + } + if (!strcmp(s, "off") || !strcmp(s, "none")) { + return DS4_STEERING_POLICY_OFF; + } + server_log(DS4_LOG_DEFAULT, "ds4-server: invalid %s value: %s", arg, s); + server_log(DS4_LOG_DEFAULT, + "ds4-server: valid directional steering policies are: final-answer, decoding, always, off"); + exit(2); +} + static server_config parse_options(int argc, char **argv) { server_config c = { .engine = { @@ -11437,6 +11628,7 @@ static server_config parse_options(int argc, char **argv) { .port = 8000, .ctx_size = 32768, .default_tokens = 393216, + .steering_policy = DS4_STEERING_POLICY_FINAL_ANSWER, .tool_memory_max_ids = DS4_TOOL_MEMORY_DEFAULT_MAX_IDS, }; c.kv_cache = kv_cache_default_options(); @@ -11497,6 +11689,9 @@ static server_config parse_options(int argc, char **argv) { } else if (!strcmp(arg, "--dir-steering-attn")) { c.engine.directional_steering_attn = parse_float_arg(need_arg(&i, argc, argv, arg), arg, -100.0f, 100.0f); directional_steering_scale_set = true; + } else if (!strcmp(arg, "--dir-steering-policy")) { + c.steering_policy = + parse_directional_steering_policy_arg(need_arg(&i, argc, argv, arg), arg); } else if (!strcmp(arg, "--warm-weights")) { c.engine.warm_weights = true; } else if (!strcmp(arg, "--metal")) { @@ -11555,6 +11750,7 @@ int main(int argc, char **argv) { memset(&s, 0, sizeof(s)); s.engine = engine; s.session = session; + s.steering_policy = cfg.steering_policy; s.default_tokens = cfg.default_tokens; s.disable_exact_dsml_tool_replay = cfg.disable_exact_dsml_tool_replay; s.tool_mem.max_entries = cfg.tool_memory_max_ids; @@ -11566,6 +11762,11 @@ int main(int argc, char **argv) { server_log(DS4_LOG_DEFAULT, "ds4-server: exact DSML tool replay disabled; tool history uses canonical JSON rendering"); } + if (s.steering_policy != DS4_STEERING_POLICY_ALWAYS) { + server_log(DS4_LOG_DEFAULT, + "ds4-server: directional steering policy=%s", + directional_steering_policy_name(s.steering_policy)); + } pthread_mutex_init(&s.mu, NULL); pthread_cond_init(&s.cv, NULL); pthread_cond_init(&s.clients_cv, NULL); @@ -13594,6 +13795,165 @@ static void test_dsml_decode_state_separates_structure_and_payload(void) { TEST_ASSERT(tracker.decode == DSML_DECODE_OUTSIDE); } +static void test_directional_steering_final_answer_policy_is_tool_safe(void) { + char *argv0[] = {"ds4-server"}; + server_config cfg = parse_options(1, argv0); + TEST_ASSERT(cfg.steering_policy == DS4_STEERING_POLICY_FINAL_ANSWER); + TEST_ASSERT(parse_directional_steering_policy_arg("decoding", "--dir-steering-policy") == + DS4_STEERING_POLICY_DECODING); + TEST_ASSERT(parse_directional_steering_policy_arg("decode", "--dir-steering-policy") == + DS4_STEERING_POLICY_DECODING); + TEST_ASSERT(!strcmp(directional_steering_policy_name(DS4_STEERING_POLICY_DECODING), + "decoding")); + + bool starts = true; + TEST_ASSERT(directional_steering_should_apply( + DS4_STEERING_POLICY_ALWAYS, + false, + false, + true, + false, + DSML_DECODE_STRUCTURAL, + DSML_DECODE_OUTSIDE, + true, + "", + 0, + &starts)); + TEST_ASSERT(starts == false); + + TEST_ASSERT(directional_steering_should_apply( + DS4_STEERING_POLICY_DECODING, + false, + false, + true, + true, + DSML_DECODE_STRUCTURAL, + DSML_DECODE_STRING_BODY, + true, + DS4_TOOL_CALLS_START, + strlen(DS4_TOOL_CALLS_START), + NULL)); + + starts = true; + TEST_ASSERT(!directional_steering_should_apply( + DS4_STEERING_POLICY_OFF, + true, + true, + false, + false, + DSML_DECODE_OUTSIDE, + DSML_DECODE_OUTSIDE, + false, + "answer", + strlen("answer"), + &starts)); + TEST_ASSERT(starts == false); + + starts = true; + TEST_ASSERT(!directional_steering_should_apply( + DS4_STEERING_POLICY_FINAL_ANSWER, + false, + false, + false, + false, + DSML_DECODE_OUTSIDE, + DSML_DECODE_OUTSIDE, + false, + "answer", + strlen("answer"), + &starts)); + TEST_ASSERT(starts == false); + + TEST_ASSERT(!directional_steering_should_apply( + DS4_STEERING_POLICY_FINAL_ANSWER, + true, + false, + true, + false, + DSML_DECODE_OUTSIDE, + DSML_DECODE_OUTSIDE, + false, + "", + strlen(""), + NULL)); + + TEST_ASSERT(!directional_steering_should_apply( + DS4_STEERING_POLICY_FINAL_ANSWER, + true, + false, + false, + false, + DSML_DECODE_STRUCTURAL, + DSML_DECODE_STRUCTURAL, + false, + DS4_TOOL_CALLS_START, + strlen(DS4_TOOL_CALLS_START), + NULL)); + + TEST_ASSERT(!directional_steering_should_apply( + DS4_STEERING_POLICY_FINAL_ANSWER, + true, + false, + false, + false, + DSML_DECODE_OUTSIDE, + DSML_DECODE_OUTSIDE, + true, + DS4_TOOL_CALLS_START, + strlen(DS4_TOOL_CALLS_START) - 2, + NULL)); + + starts = false; + TEST_ASSERT(directional_steering_should_apply( + DS4_STEERING_POLICY_FINAL_ANSWER, + true, + false, + false, + false, + DSML_DECODE_OUTSIDE, + DSML_DECODE_OUTSIDE, + false, + "answer", + strlen("answer"), + &starts)); + TEST_ASSERT(starts == true); + + starts = true; + TEST_ASSERT(directional_steering_should_apply( + DS4_STEERING_POLICY_FINAL_ANSWER, + true, + true, + false, + false, + DSML_DECODE_OUTSIDE, + DSML_DECODE_OUTSIDE, + false, + " ", + 1, + &starts)); + TEST_ASSERT(starts == false); + + request r = { + .kind = REQ_CHAT, + .has_tools = true, + .prompt_text = "user asks before any tool result", + }; + TEST_ASSERT(!directional_steering_final_answer_context(&r, false, false)); + TEST_ASSERT(directional_steering_final_answer_context(&r, true, false)); + r.prompt_text = "ok"; + TEST_ASSERT(directional_steering_final_answer_context(&r, false, false)); + r.has_tools = false; + r.prompt_text = NULL; + TEST_ASSERT(directional_steering_final_answer_context(&r, false, false)); + + request c = {.kind = REQ_COMPLETION}; + TEST_ASSERT(directional_steering_final_answer_context(&c, false, false)); + TEST_ASSERT(dsml_text_ends_with_partial_tool_start( + DS4_TOOL_CALLS_START, + strlen(DS4_TOOL_CALLS_START) - 2)); + TEST_ASSERT(!dsml_text_ends_with_partial_tool_start("plain", strlen("plain"))); +} + static void test_tool_memory_max_ids_prunes_oldest(void) { const char *a_dsml = "\n\n<|DSML|tool_calls>\n<|DSML|invoke name=\"bash\">\n<|DSML|parameter name=\"command\" string=\"true\">a\n\n"; const char *b_dsml = "\n\n<|DSML|tool_calls>\n<|DSML|invoke name=\"bash\">\n<|DSML|parameter name=\"command\" string=\"true\">b\n\n"; @@ -14507,6 +14867,7 @@ static void ds4_server_unit_tests_run(void) { test_responses_visible_suffix_matches_client_replay(); test_exact_dsml_tool_replay_can_be_disabled(); test_dsml_decode_state_separates_structure_and_payload(); + test_directional_steering_final_answer_policy_is_tool_safe(); test_tool_memory_max_ids_prunes_oldest(); test_kv_tool_map_filters_by_dsml_text(); test_kv_tool_map_restores_before_prompt_render();