Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions acp/internal/controller/task/task_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package task
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"time"

Expand Down Expand Up @@ -787,7 +789,6 @@ func (r *TaskReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.
return ctrl.Result{}, nil
}

// SetupWithManager sets up the controller with the Manager.
// notifyResponseURLAsync sends the final task result to the response URL asynchronously
func (r *TaskReconciler) notifyResponseURLAsync(task *acp.Task, result string) {
go func() {
Expand All @@ -797,7 +798,7 @@ func (r *TaskReconciler) notifyResponseURLAsync(task *acp.Task, result string) {
logger := log.FromContext(ctx)
taskCopy := task.DeepCopy()

err := r.sendFinalResultToResponseURL(ctx, task.Spec.ResponseURL, result)
err := r.sendFinalResultToResponseURL(ctx, task, result)
if err != nil {
logger.Error(err, "Failed to send final result to responseURL",
"responseURL", task.Spec.ResponseURL,
Expand All @@ -815,10 +816,43 @@ func (r *TaskReconciler) notifyResponseURLAsync(task *acp.Task, result string) {
}()
}

// assertAvailablePRNG ensures that a cryptographically secure PRNG is available
func assertAvailablePRNG() {
buf := make([]byte, 1)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err))
}
}

// init ensures that a cryptographically secure PRNG is available when the package is loaded
func init() {
assertAvailablePRNG()
}

// generateRandomString returns a securely generated random string
func generateRandomString(n int) (string, error) {
const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-"
ret := make([]byte, n)
for i := 0; i < n; i++ {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
ret[i] = letters[num.Int64()]
}
return string(ret), nil
}

// createHumanContactRequest builds the request payload for sending to a response URL
func createHumanContactRequest(result string) ([]byte, error) {
runID := uuid.New().String()
callID := uuid.New().String()
func createHumanContactRequest(agentName string, result string) ([]byte, error) {
// Use agent name as runId
runID := agentName
// Generate a secure random string for callId
callID, err := generateRandomString(7)
if err != nil {
return nil, fmt.Errorf("failed to generate secure random string: %w", err)
}
spec := humanlayerapi.NewHumanContactSpecInput(result)
input := humanlayerapi.NewHumanContactInput(runID, callID, *spec)
return json.Marshal(input)
Expand All @@ -831,12 +865,12 @@ func isRetryableStatusCode(statusCode int) bool {

// sendFinalResultToResponseURL sends the final task result to the specified URL
// It includes retry logic for transient errors and better error categorization
func (r *TaskReconciler) sendFinalResultToResponseURL(ctx context.Context, responseURL string, result string) error {
func (r *TaskReconciler) sendFinalResultToResponseURL(ctx context.Context, task *acp.Task, result string) error {
logger := log.FromContext(ctx)
logger.Info("Sending final result to responseURL", "responseURL", responseURL)
logger.Info("Sending final result to responseURL", "responseURL", task.Spec.ResponseURL)

// Create the request body
jsonData, err := createHumanContactRequest(result)
jsonData, err := createHumanContactRequest(task.Spec.AgentRef.Name, result)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
Expand All @@ -846,9 +880,9 @@ func (r *TaskReconciler) sendFinalResultToResponseURL(ctx context.Context, respo
initialDelay := 1 * time.Second

// Retry the operation with exponential backoff
return retryWithBackoff(ctx, maxRetries, initialDelay, responseURL, func() (bool, error) {
return retryWithBackoff(ctx, maxRetries, initialDelay, task.Spec.ResponseURL, func() (bool, error) {
// Create the HTTP request
req, err := http.NewRequestWithContext(ctx, "POST", responseURL, bytes.NewBuffer(jsonData))
req, err := http.NewRequestWithContext(ctx, "POST", task.Spec.ResponseURL, bytes.NewBuffer(jsonData))
if err != nil {
return false, fmt.Errorf("failed to create HTTP request: %w", err) // Non-retryable
}
Expand Down Expand Up @@ -893,15 +927,15 @@ func (r *TaskReconciler) sendFinalResultToResponseURL(ctx context.Context, respo
// Success case
logger.Info("Successfully sent final result to responseURL",
"statusCode", resp.StatusCode,
"responseURL", responseURL)
"responseURL", task.Spec.ResponseURL)
return false, nil
})
}

// retryWithBackoff executes an operation with exponential backoff
func retryWithBackoff(ctx context.Context, maxRetries int, initialDelay time.Duration,
responseURL string, operation func() (bool, error)) error {

responseURL string, operation func() (bool, error),
) error {
logger := log.FromContext(ctx)
var lastErr error
delay := initialDelay
Expand Down
40 changes: 39 additions & 1 deletion acp/internal/server/server.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package server

import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand All @@ -22,6 +25,8 @@ type CreateTaskRequest struct {
AgentName string `json:"agentName"` // Required
UserMessage string `json:"userMessage,omitempty"` // Optional if contextWindow is provided
ContextWindow []acp.Message `json:"contextWindow,omitempty"` // Optional if userMessage is provided
ResponseURL string `json:"responseURL,omitempty"` // Optional, URL for receiving task results
ResponseUrl string `json:"responseUrl,omitempty"` // Alternative casing for responseURL (deprecated)
}

// APIServer represents the REST API server
Expand Down Expand Up @@ -156,12 +161,36 @@ func (s *APIServer) createTask(c *gin.Context) {
ctx := c.Request.Context()
logger := log.FromContext(ctx)

// First, read the raw data and store it for validation
var rawData []byte
if data, err := c.GetRawData(); err == nil {
rawData = data
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body: " + err.Error()})
return
}

// First parse to basic binding
var req CreateTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
if err := json.Unmarshal(rawData, &req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: " + err.Error()})
return
}

// Then check for unknown fields with a more strict decoder
decoder := json.NewDecoder(bytes.NewReader(rawData))
decoder.DisallowUnknownFields()
if err := decoder.Decode(&req); err != nil {
// Check if it's an unknown field error
if strings.Contains(err.Error(), "unknown field") {
c.JSON(http.StatusBadRequest, gin.H{"error": "Unknown field in request: " + err.Error()})
return
}
// For other JSON errors
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON format: " + err.Error()})
return
}

if req.AgentName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "agentName is required"})
return
Expand All @@ -177,6 +206,14 @@ func (s *APIServer) createTask(c *gin.Context) {
namespace = "default"
}

// Handle both responseURL and responseUrl fields (with responseURL taking precedence)
responseURL := req.ResponseURL
if responseURL == "" && req.ResponseUrl != "" {
responseURL = req.ResponseUrl
logger.Info("Using deprecated 'responseUrl' field, please use 'responseURL' instead",
"responseUrl", req.ResponseUrl)
}

// Check if agent exists
var agent acp.Agent
err := s.client.Get(ctx, client.ObjectKey{Namespace: namespace, Name: req.AgentName}, &agent)
Expand Down Expand Up @@ -205,6 +242,7 @@ func (s *APIServer) createTask(c *gin.Context) {
AgentRef: acp.LocalObjectReference{Name: req.AgentName},
UserMessage: req.UserMessage,
ContextWindow: req.ContextWindow,
ResponseURL: responseURL,
},
}

Expand Down
Loading