From 0de84e335650ccded4faa792896d3b5c78bf8680 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Thu, 14 May 2026 18:22:29 +0530 Subject: [PATCH] stat_tracking: fix AttributeError on second update() call caused by in-place np.stack overwrite --- flow_grpo/flow_grpo/stat_tracking.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flow_grpo/flow_grpo/stat_tracking.py b/flow_grpo/flow_grpo/stat_tracking.py index 0166c10..1cb0349 100644 --- a/flow_grpo/flow_grpo/stat_tracking.py +++ b/flow_grpo/flow_grpo/stat_tracking.py @@ -20,13 +20,13 @@ def update(self, prompts, rewards, type='grpo'): self.stats[prompt].extend(prompt_rewards) self.history_prompts.add(hash(prompt)) # Add hash of prompt to history_prompts for prompt in unique: - self.stats[prompt] = np.stack(self.stats[prompt]) + prompt_history = np.array(self.stats[prompt]) prompt_rewards = rewards[prompts == prompt] # Fix: Recalculate prompt_rewards for each prompt - mean = np.mean(self.stats[prompt], axis=0, keepdims=True) + mean = np.mean(prompt_history, axis=0, keepdims=True) if self.global_std: std = np.std(rewards, axis=0, keepdims=True) + 1e-4 # Use global std of all rewards else: - std = np.std(self.stats[prompt], axis=0, keepdims=True) + 1e-4 + std = np.std(prompt_history, axis=0, keepdims=True) + 1e-4 if type=='grpo': advantages[prompts == prompt] = (prompt_rewards - mean) / std elif type=='rwr': @@ -75,4 +75,4 @@ def main(): print("Stats after clear:", tracker.stats) if __name__ == "__main__": - main() \ No newline at end of file + main()