diff --git a/.local.env b/.local.env index a1973fd1..d82cabb4 100644 --- a/.local.env +++ b/.local.env @@ -1,8 +1,8 @@ -SENTRIUS_VERSION=1.1.193 -SENTRIUS_SSH_VERSION=1.1.35 -SENTRIUS_KEYCLOAK_VERSION=1.1.47 -SENTRIUS_AGENT_VERSION=1.1.34 -SENTRIUS_AI_AGENT_VERSION=1.1.64 -LLMPROXY_VERSION=1.0.46 -LAUNCHER_VERSION=1.0.51 -AGENTPROXY_VERSION=1.0.66 \ No newline at end of file +SENTRIUS_VERSION=1.1.261 +SENTRIUS_SSH_VERSION=1.1.40 +SENTRIUS_KEYCLOAK_VERSION=1.1.52 +SENTRIUS_AGENT_VERSION=1.1.39 +SENTRIUS_AI_AGENT_VERSION=1.1.148 +LLMPROXY_VERSION=1.0.53 +LAUNCHER_VERSION=1.0.73 +AGENTPROXY_VERSION=1.0.74 \ No newline at end of file diff --git a/.local.env.bak b/.local.env.bak index a1973fd1..d82cabb4 100644 --- a/.local.env.bak +++ b/.local.env.bak @@ -1,8 +1,8 @@ -SENTRIUS_VERSION=1.1.193 -SENTRIUS_SSH_VERSION=1.1.35 -SENTRIUS_KEYCLOAK_VERSION=1.1.47 -SENTRIUS_AGENT_VERSION=1.1.34 -SENTRIUS_AI_AGENT_VERSION=1.1.64 -LLMPROXY_VERSION=1.0.46 -LAUNCHER_VERSION=1.0.51 -AGENTPROXY_VERSION=1.0.66 \ No newline at end of file +SENTRIUS_VERSION=1.1.261 +SENTRIUS_SSH_VERSION=1.1.40 +SENTRIUS_KEYCLOAK_VERSION=1.1.52 +SENTRIUS_AGENT_VERSION=1.1.39 +SENTRIUS_AI_AGENT_VERSION=1.1.148 +LLMPROXY_VERSION=1.0.53 +LAUNCHER_VERSION=1.0.73 +AGENTPROXY_VERSION=1.0.74 \ No newline at end of file diff --git a/agent-launcher/src/main/java/io/sentrius/agent/launcher/api/AgentLauncherController.java b/agent-launcher/src/main/java/io/sentrius/agent/launcher/api/AgentLauncherController.java index a8665b42..8d6af2bc 100644 --- a/agent-launcher/src/main/java/io/sentrius/agent/launcher/api/AgentLauncherController.java +++ b/agent-launcher/src/main/java/io/sentrius/agent/launcher/api/AgentLauncherController.java @@ -12,7 +12,6 @@ import lombok.extern.slf4j.Slf4j; import org.apache.http.HttpStatus; import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -50,8 +49,7 @@ public ResponseEntity createPod( return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("Invalid Keycloak token"); } - var clientId = agent.getAgentName(); - podLauncherService.launchAgentPod(clientId, agent.getAgentCallbackUrl()); + podLauncherService.launchAgentPod(agent); return ResponseEntity.ok(Map.of("status", "success")); } @@ -62,9 +60,18 @@ public ResponseEntity deleteAgent(@RequestParam(name="agentId") String a podLauncherService.deleteAgentById(agentId); return ResponseEntity.ok("Shutdown triggered"); } catch (Exception e) { - e.printStackTrace(); return ResponseEntity.status(500).body("Shutdown failed: " + e.getMessage()); } } + @GetMapping("/status") + public ResponseEntity getAgentStatus(@RequestParam(name="agentId") String agentId) { + try { + return ResponseEntity.ok(podLauncherService.statusById(agentId) ); + } catch (Exception e) { + log.error("Status failed", e); + return ResponseEntity.status(500).body("Status retrieval failed"); + } + } + } diff --git a/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java b/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java index 84b23475..da128093 100644 --- a/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java +++ b/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java @@ -3,14 +3,17 @@ import io.kubernetes.client.custom.IntOrString; import io.kubernetes.client.custom.Quantity; import io.kubernetes.client.openapi.ApiClient; +import io.kubernetes.client.openapi.ApiException; import io.kubernetes.client.openapi.apis.CoreV1Api; import io.kubernetes.client.openapi.models.*; import io.kubernetes.client.util.Config; +import io.sentrius.sso.core.dto.AgentRegistrationDTO; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.regex.Matcher; @@ -100,7 +103,7 @@ public void deleteAgentById(String agentId) throws Exception { } catch (Exception ex) { log.warn("Service not found or already deleted: {}", ex.getMessage()); } - }else { + } else { log.info("Not Deleting pod: {}", podName); } @@ -111,12 +114,52 @@ public void deleteAgentById(String agentId) throws Exception { } + } - } + public String statusById(String agentId) throws Exception { + // Delete all pods with this agentId label + var pods = coreV1Api.listNamespacedPod( + agentNamespace + ).execute().getItems(); + + for (V1Pod pod : pods) { + + var labels = pod.getMetadata().getLabels(); + var podName = pod.getMetadata().getName(); + + Matcher matcher = pattern.matcher(agentId); + + if (matcher.matches() && labels != null && labels.containsKey("agentId")) { + String name = matcher.group(1); + + var value = labels.get("agentId"); + if (value.equals(name)) { + // get pod status + // + V1PodStatus status = pod.getStatus(); + if (status == null) { + log.warn("Pod {} has no status information", podName); + return "Unknown"; + } + return status.getPhase(); // e.g., "Running", "Pending", "Failed", "Succeeded" + + } + + + } + + + } + return "NotFound"; + } + - public V1Pod launchAgentPod(String agentId, String callbackUrl) throws Exception { + + + + public V1Pod launchAgentPod(AgentRegistrationDTO agent) throws Exception { var myAgentRegistry = ""; if (agentRegistry != null ) { if ("local".equalsIgnoreCase(agentRegistry)) { @@ -125,9 +168,37 @@ public V1Pod launchAgentPod(String agentId, String callbackUrl) throws Exception myAgentRegistry += "/"; } } + String agentId = agent.getAgentName().toLowerCase(); + String callbackUrl = agent.getAgentCallbackUrl(); + String agentType = agent.getAgentType(); var constructedCallbackUrl = buildAgentCallbackUrl(agentId); + + List argList = new ArrayList<>(); + argList.add("--spring.config.location=file:/config/agent.properties"); + argList.add("--agent.namePrefix=" + agentId); + argList.add("--agent.listen.websocket=true"); + argList.add("--agent.callback.url=" + constructedCallbackUrl); + if (agent.getAgentContextId() != null && !agent.getAgentContextId().isEmpty()) { + argList.add("--agent.ai.context.db.id=" + agent.getAgentContextId()); + }else { + String agentFile= "chat-helper.yaml"; + switch(agentType){ + case "chat": + agentFile = "chat-helper.yaml"; + break; + case "atpl-helper": + agentFile = "chat-atpl-helper.yaml"; + break; + case "default": + default: + agentFile = "chat-helper.yaml"; + } + argList.add("--agent.ai.config=/config/" + agentFile); + } + + String image = String.format("%ssentrius-launchable-agent:%s", myAgentRegistry, agentVersion); log.info("Launching agent pod with ID: {}, Image: {}, Callback URL: {}", agentId, image, callbackUrl); @@ -141,10 +212,7 @@ public V1Pod launchAgentPod(String agentId, String callbackUrl) throws Exception .image(image) .imagePullPolicy("IfNotPresent") - .args(List.of("--spring.config.location=file:/config/agent.properties", - "--agent.namePrefix=" + agentId, "--agent.ai.config=/config/chat-helper.yaml", "--agent.listen.websocket=true", - "--agent.callback.url=" + constructedCallbackUrl - )) + .args(argList) .resources(new V1ResourceRequirements() .limits(Map.of( "cpu", Quantity.fromString("1000m"), @@ -169,24 +237,34 @@ public V1Pod launchAgentPod(String agentId, String callbackUrl) throws Exception var createdPod = coreV1Api.createNamespacedPod(agentNamespace, pod).execute(); - // Create corresponding service for WebSocket routing - V1Service service = new V1Service() - .metadata(new V1ObjectMeta() - .name("sentrius-agent-" + agentId) - .labels(Map.of("agentId", agentId))) - .spec(new V1ServiceSpec() - .selector(Map.of("agentId", agentId)) - .ports(List.of(new V1ServicePort() - .protocol("TCP") - .port(8090) - .targetPort(new IntOrString(8090)) - )) - .type("ClusterIP") - ); - - log.info("Created service pod: {} and service {}", createdPod, service); - coreV1Api.createNamespacedService(agentNamespace, service).execute(); - + try { + // Create corresponding service for WebSocket routing + V1Service service = new V1Service() + .metadata(new V1ObjectMeta() + .name("sentrius-agent-" + agentId) + .labels(Map.of("agentId", agentId))) + .spec(new V1ServiceSpec() + .selector(Map.of("agentId", agentId)) + .ports(List.of(new V1ServicePort() + .protocol("TCP") + .port(8090) + .targetPort(new IntOrString(8090)) + )) + .type("ClusterIP") + ); + + log.info("Created service pod: {} and service {}", createdPod, service); + coreV1Api.createNamespacedService(agentNamespace, service).execute(); + + }catch(ApiException e){ + if (e.getCode() == 409){ + log.info("Service for agent {} already exists, skipping creation", agentId); + } + else{ + throw e; + } + } return createdPod; } + } diff --git a/agent-proxy/src/main/java/io/sentrius/sso/config/AgentWebSocketProxyHandler.java b/agent-proxy/src/main/java/io/sentrius/sso/config/AgentWebSocketProxyHandler.java index ecbcfa91..109128a2 100644 --- a/agent-proxy/src/main/java/io/sentrius/sso/config/AgentWebSocketProxyHandler.java +++ b/agent-proxy/src/main/java/io/sentrius/sso/config/AgentWebSocketProxyHandler.java @@ -65,7 +65,7 @@ public Mono handle(WebSocketSession clientSession) { log.info("Handling WebSocket connection for host: {}, sessionId: {}, chatGroupId: {}, ztat: {}", agentHost, sessionId, chatGroupId, ztat); - URI agentUri = agentLocator.resolveWebSocketUri(agentHost, sessionId, chatGroupId, ztat); + URI agentUri = agentLocator.resolveWebSocketUri(agentHost.toLowerCase(), sessionId, chatGroupId, ztat); log.info("Resolved agent URI: {}", agentUri); @@ -123,22 +123,38 @@ public Mono handle(WebSocketSession clientSession) { }) .as(clientSession::send) .doOnSuccess(aVoid -> log.info("agent -> client completed gracefully")) // Corrected for Mono - .doOnError(e -> log.error("Error in agent -> client stream", e)) + .doOnError(e -> { + log.error("Error in agent -> client stream", e); + sessionManager.unregister(agentSession.getId()); + }) .onErrorResume(e -> { + sessionManager.unregister(agentSession.getId()); log.error("Agent to client stream error, closing agent session.", e); return agentSession.close().then(Mono.empty()); }) .doFinally(sig -> log.info("Agent to client stream finalized: {}", sig)); return Mono.when(clientToAgent, agentToClient) - .doOnTerminate(() -> log.info("WebSocket proxy connection terminated (client and agent streams completed/cancelled)")) - .doOnError(e -> log.error("Overall proxy connection failed", e)) + .doOnTerminate(() -> { + log.info("WebSocket proxy connection terminated (client and agent " + + "streams completed/cancelled)"); + sessionManager.unregister(agentSession.getId()); + + }) + .doOnError(e -> { + log.error("Overall proxy connection failed", e); + sessionManager.unregister(agentSession.getId()); + + }) .doFinally(sig -> { sessionManager.unregister(finalSessionId); log.info("WebSocket proxy stream closed completely: {}. Final session ID: {}", sig, finalSessionId); }); } - ).doOnError(e -> log.error("Failed to establish proxy connection", e)); + ).doOnError(e -> { + log.error("Failed to establish proxy connection", e); + sessionManager.unregister(finalSessionId); + }); } catch (Exception ex) { diff --git a/agent-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java b/agent-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java index 4595af51..17cf55df 100644 --- a/agent-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java +++ b/agent-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java @@ -55,10 +55,6 @@ private ReactiveJwtAuthenticationConverter grantedAuthoritiesExtractor() { converter.setJwtGrantedAuthoritiesConverter(jwt -> { Collection authorities = new JwtGrantedAuthoritiesConverter().convert(jwt); - log.info("JWT Claims: {}", jwt.getClaims()); - - String username = jwt.getClaimAsString("preferred_username"); - String email = jwt.getClaimAsString("email"); return Flux.fromIterable(authorities); }); diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/AgentVerb.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/AgentVerb.java index 523997ba..aeef6f6a 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/AgentVerb.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/AgentVerb.java @@ -24,4 +24,6 @@ public class AgentVerb { @Builder.Default Class outputInterpreter = DefaultInterpreter.class; Class inputInterpreter = DefaultInterpreter.class; + + private String exampleJson = ""; } \ No newline at end of file diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java index d0580c93..6a2e87be 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java @@ -135,6 +135,14 @@ public void onApplicationEvent(final ApplicationReadyEvent event) { } } + try { + verbRegistry.scanEndpoints(agentExecution); + } catch (ZtatException e) { + throw new RuntimeException(e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + while(running) { log.info("Agent Registered..."); diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/PromptBuilder.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/PromptBuilder.java index 924cb95e..12c7a62a 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/PromptBuilder.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/PromptBuilder.java @@ -3,6 +3,8 @@ import java.lang.reflect.Method; import java.util.Arrays; import java.util.stream.Collectors; +import io.sentrius.agent.analysis.agents.verbs.ExampleFactory; +import io.sentrius.sso.core.utils.JsonUtil; /** * The `PromptBuilder` class is responsible for constructing a prompt string @@ -25,12 +27,17 @@ public PromptBuilder(VerbRegistry verbRegistry, AgentConfig agentConfig) { this.agentConfig = agentConfig; } + public String buildPrompt(){ + return buildPrompt(true); + } + /** * Builds a prompt string that includes roles, context, instructions, and available verbs. * * @return A formatted prompt string. */ - public String buildPrompt() { + public String buildPrompt(boolean applyInstructions) + { StringBuilder prompt = new StringBuilder(); // Append roles to the prompt @@ -39,36 +46,67 @@ public String buildPrompt() { // Append context to the prompt prompt.append("Context: ").append(agentConfig.getContext()).append("\n\n"); - // Append instructions for using the JSON format - prompt.append("Instructions: ").append("Respond using this JSON format. Only use verbs provided in " + - "Available Verbs. Formulate a complete plan with all possible steps.:\n" + - "\n" + - "{\n" + - " \"plan\": [\n" + - " {\n" + - " \"verb\": \"list_open_terminals\",\n" + - " \"params\": {}\n" + - " },\n" + - " {\n" + - " \"verb\": \"send_terminal_command\",\n" + - " \"params\": {}\n" + - " }\n" + - " ]\n" + - "}\n" ); - - // Append the list of available verbs - prompt.append("Available Verbs:\n"); - - // Iterate through the verbs in the registry and append their details - verbRegistry.getVerbs().forEach((name, verb) -> { - prompt.append("- ").append(name); - prompt.append(" (").append(buildMethodSignature(verb.getMethod())).append(") - "); - prompt.append(verb.getDescription()).append("\n"); - }); + if (applyInstructions) { + // Append instructions for using the JSON format + prompt.append("Instructions: ").append("Respond using this JSON format. Only use verbs provided in " + + "Available Verbs. Formulate a complete plan with all possible steps.:\n" + + "\n" + + "{\n" + + " \"plan\": [\n" + + " {\n" + + " \"verb\": \"list_open_terminals\",\n" + + " \"params\": {}\n" + + " },\n" + + " {\n" + + " \"verb\": \"send_terminal_command\",\n" + + " \"params\": {}\n" + + " }\n" + + " ]\n" + + "}\n"); + } + // Append the list of available verbs + prompt.append("Verb operations:\n"); + + // Iterate through the verbs in the registry and append their details + verbRegistry.getVerbs().forEach((name, verb) -> { + prompt.append("- ").append(name); + prompt.append(" (").append(buildMethodSignature(verb.getMethod())).append(") - "); + prompt.append(verb.getDescription()).append("\n"); + // Optionally generate example params based on arg1 class + Class[] paramTypes = verb.getMethod().getParameterTypes(); + + if (paramTypes.length > 1 && !paramTypes[1].equals(Void.class)) { + var paramName = verb.getMethod().getParameters()[1].getName(); + Object example = ExampleFactory.createExample(paramName, paramTypes[1]); // create a stub from + // your DTO + try { + if (verb.getExampleJson() != null && !verb.getExampleJson().isEmpty()) { + prompt.append(" Example arg1: ").append(verb.getExampleJson()).append("\n"); + } else if (example != null) { + // Serialize the example object to JSON + String exampleJson = JsonUtil.MAPPER.writeValueAsString(example); + prompt.append(" Example arg1: ").append(exampleJson).append("\n"); + } + + } catch (Exception e) { + prompt.append(" Example params: [unavailable due to serialization error]\n"); + } + } else { + prompt.append(" Example params: {}\n"); + } + }); return prompt.toString(); } + public static String indent(String input, int spaces) { + String indent = " ".repeat(spaces); + return Arrays.stream(input.split("\n")) + .map(line -> indent + line) + .collect(Collectors.joining("\n")); + } + + /** * Builds a method signature string for a given method. * @@ -77,6 +115,8 @@ public String buildPrompt() { */ private String buildMethodSignature(Method method) { return Arrays.stream(method.getParameters()) + .filter( p -> !p.getType().getSimpleName().equalsIgnoreCase("TokenDTO") && + !p.getType().getSimpleName().equalsIgnoreCase("AgentExecution")) .map(p -> p.getName() + ": " + p.getType().getSimpleName()) .collect(Collectors.joining(", ")); } diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java index f7ae840f..6c2c9fb7 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java @@ -2,7 +2,6 @@ import java.util.HashMap; import java.util.Map; -import java.util.UUID; import java.util.concurrent.TimeUnit; import com.fasterxml.jackson.databind.node.ArrayNode; import io.sentrius.agent.analysis.agents.verbs.AgentVerbs; diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/VerbRegistry.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/VerbRegistry.java index 4c6496c7..c9b0e963 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/VerbRegistry.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/VerbRegistry.java @@ -1,5 +1,6 @@ package io.sentrius.agent.analysis.agents.agents; +import com.fasterxml.jackson.core.JsonProcessingException; import io.github.classgraph.ClassGraph; import io.github.classgraph.ScanResult; import io.sentrius.sso.core.dto.capabilities.EndpointDescriptor; @@ -10,6 +11,7 @@ import io.sentrius.sso.core.model.verbs.OutputInterpreterIfc; import io.sentrius.sso.core.model.verbs.Verb; import io.sentrius.sso.core.model.verbs.VerbResponse; +import io.sentrius.sso.core.services.agents.AgentClientService; import io.sentrius.sso.core.services.agents.ZeroTrustClientService; import io.sentrius.sso.core.services.capabilities.EndpointScanningService; import lombok.RequiredArgsConstructor; @@ -20,6 +22,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -34,6 +37,8 @@ public class VerbRegistry { private final ApplicationContext applicationContext; private final ZeroTrustClientService zeroTrustClientService; + + private final AgentClientService agentClientService; private final EndpointScanningService endpointScanningService; @@ -42,6 +47,26 @@ public class VerbRegistry { private final AgentEndpointDiscoveryService agentEndpointDiscoveryService; + private List endpoints = new ArrayList<>(); + + public void scanEndpoints(AgentExecution execution) throws ZtatException, JsonProcessingException { + synchronized (this) { + var endpoints = agentClientService.getAvailableEndpoints(execution); + log.info("Scanning endpoints for verbs..."); + var verbs = agentClientService.getAvailableVerbs(execution); + + endpoints.forEach(x -> { + log.info("Discovered endpoint: {}", x); + }); + + this.endpoints.addAll(endpoints); + + verbs.forEach(x -> { + log.info("Discovered verb: {}", x); + }); + } + } + public void scanClasspath() { // Scan the classpath for classes with the @Verb annotation synchronized (this) { @@ -68,6 +93,7 @@ public void scanClasspath() { .name(name) .description(annotation.description()) .method(method) + .exampleJson(annotation.exampleJson()) .requiresTokenManagement(annotation.requiresTokenManagement()) .returnType(annotation.returnType()) .outputInterpreter(annotation.outputInterpreter()) @@ -83,6 +109,7 @@ public void scanClasspath() { } }); } + } } @@ -173,6 +200,10 @@ public VerbResponse execute(AgentExecution agentExecution, VerbResponse priorRes } } + public List getEndpoints() { + return endpoints; + } + public Map getVerbs() { return new HashMap<>(verbs); } @@ -202,4 +233,5 @@ public List getAiCallableVerbDescriptors() { }) .collect(Collectors.toList()); } + } diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/AgentContextInterpreter.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/AgentContextInterpreter.java new file mode 100644 index 00000000..1743671d --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/AgentContextInterpreter.java @@ -0,0 +1,22 @@ +package io.sentrius.agent.analysis.agents.interpreters; + +import java.util.Map; +import io.sentrius.sso.core.dto.agents.AgentContextDTO; +import io.sentrius.sso.core.model.verbs.InputInterpreterIfc; +import io.sentrius.sso.core.utils.JsonUtil; + + +public class AgentContextInterpreter implements InputInterpreterIfc { + @Override + public AgentContextDTO interpret(Map input) throws Exception { + Object agentContextObj = input.get("agentContext"); + Object arg1Obj = input.get("arg1"); + + if (agentContextObj != null) { + return JsonUtil.MAPPER.convertValue(agentContextObj, AgentContextDTO.class); + } else if (arg1Obj != null) { + return JsonUtil.MAPPER.convertValue(arg1Obj, AgentContextDTO.class); + } + return null; + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/ObjectNodeInterpreter.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/ObjectNodeInterpreter.java new file mode 100644 index 00000000..d4ea3a81 --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/ObjectNodeInterpreter.java @@ -0,0 +1,44 @@ +package io.sentrius.agent.analysis.agents.interpreters; + +import java.util.HashMap; +import java.util.Map; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.sentrius.sso.core.model.verbs.InputInterpreterIfc; +import io.sentrius.sso.core.model.verbs.OutputInterpreterIfc; +import io.sentrius.sso.core.model.verbs.VerbResponse; +import io.sentrius.sso.core.utils.JsonUtil; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class ObjectNodeInterpreter implements InputInterpreterIfc, OutputInterpreterIfc { + @Override + public ObjectNode interpret(Map input) throws Exception { + Object raw = input.containsKey("arg1") ? input.get("arg1") : input; + + JsonNode node = JsonUtil.MAPPER.valueToTree(raw); + + if (node == null || node.isNull()) { + throw new IllegalArgumentException("Input is null or could not be converted to ObjectNode"); + } + + if (!node.isObject()) { + throw new IllegalArgumentException("Expected ObjectNode, got: " + node.getNodeType()); + } + + return (ObjectNode) node; + + } + + @Override + public Map interpret(VerbResponse input) throws Exception { + if (input.getResponse() instanceof ObjectNode) { + Map responseMap = new HashMap<>(); + ((ObjectNode)input.getResponse()).fieldNames().forEachRemaining( x ->{ + responseMap.put(x, ((ObjectNode)input.getResponse()).get(x).asText() ); + }); + return responseMap; + } + return Map.of(); + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/StringInterpreter.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/StringInterpreter.java new file mode 100644 index 00000000..996a3bc9 --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/StringInterpreter.java @@ -0,0 +1,22 @@ +package io.sentrius.agent.analysis.agents.interpreters; + +import java.util.Map; +import io.sentrius.sso.core.model.verbs.InputInterpreterIfc; +import io.sentrius.sso.core.trust.ATPLPolicy; +import io.sentrius.sso.core.utils.JsonUtil; + +public class StringInterpreter implements InputInterpreterIfc { + + @Override + public ATPLPolicy interpret(Map input) throws Exception { + Object policyObj = input.get("policy"); + Object arg1Obj = input.get("arg1"); + + if (policyObj != null) { + return JsonUtil.MAPPER.convertValue(policyObj, ATPLPolicy.class); + } else if (arg1Obj != null) { + return JsonUtil.MAPPER.convertValue(arg1Obj, ATPLPolicy.class); + } + return null; + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/StringToAtplInterpreter.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/StringToAtplInterpreter.java new file mode 100644 index 00000000..9852f5f4 --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/interpreters/StringToAtplInterpreter.java @@ -0,0 +1,23 @@ +package io.sentrius.agent.analysis.agents.interpreters; + +import java.util.Map; +import com.fasterxml.jackson.core.JsonParser; +import io.sentrius.agent.analysis.model.TerminalResponse; +import io.sentrius.sso.core.model.verbs.InputInterpreterIfc; +import io.sentrius.sso.core.trust.ATPLPolicy; +import io.sentrius.sso.core.utils.JsonUtil; + +public class StringToAtplInterpreter implements InputInterpreterIfc { + @Override + public ATPLPolicy interpret(Map input) throws Exception { + Object policyObj = input.get("policy"); + Object arg1Obj = input.get("arg1"); + + if (policyObj != null) { + return JsonUtil.MAPPER.convertValue(policyObj, ATPLPolicy.class); + } else if (arg1Obj != null) { + return JsonUtil.MAPPER.convertValue(arg1Obj, ATPLPolicy.class); + } + return null; + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java index d8224979..6d3e2752 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java @@ -10,12 +10,14 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -24,15 +26,17 @@ import io.sentrius.agent.analysis.agents.agents.VerbRegistry; import io.sentrius.agent.analysis.agents.interpreters.AsessmentListInterpreter; import io.sentrius.agent.analysis.agents.interpreters.ObjectListInterpreter; +import io.sentrius.agent.analysis.agents.interpreters.ObjectNodeInterpreter; import io.sentrius.agent.analysis.agents.interpreters.ZtatOutputInterpreter; import io.sentrius.agent.analysis.model.AssessedTerminal; import io.sentrius.agent.analysis.model.Assessment; -import io.sentrius.agent.analysis.model.TerminalResponse; -import io.sentrius.agent.analysis.model.WebSocky; import io.sentrius.agent.analysis.model.ZtatAsessment; import io.sentrius.agent.analysis.model.ZtatResponse; import io.sentrius.sso.core.dto.AgentCommunicationDTO; +import io.sentrius.sso.core.dto.AgentRegistrationDTO; import io.sentrius.sso.core.dto.ZtatDTO; +import io.sentrius.sso.core.dto.agents.AgentContextDTO; +import io.sentrius.sso.core.dto.agents.AgentContextRequestDTO; import io.sentrius.sso.core.dto.ztat.AgentExecution; import io.sentrius.sso.core.dto.ztat.AtatRequest; import io.sentrius.sso.core.dto.ztat.ZtatRequestDTO; @@ -45,7 +49,6 @@ import io.sentrius.sso.genai.Message; import io.sentrius.sso.genai.Response; import io.sentrius.sso.genai.model.LLMRequest; -import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @@ -57,16 +60,14 @@ */ @Service @Slf4j -public class AgentVerbs { +public class AgentVerbs extends VerbBase { final ZeroTrustClientService zeroTrustClientService; final LLMService llmService; final VerbRegistry verbRegistry; - final AgentClientService agentClientService; - @Value("${agent.ai.config}") - private String agentConfigFile; + final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); // Jackson ObjectMapper for YAML parsing @@ -78,13 +79,15 @@ public class AgentVerbs { * @param verbRegistry The registry containing available verbs and their metadata. * @throws JsonProcessingException If there is an error processing JSON during initialization. */ - public AgentVerbs(ZeroTrustClientService zeroTrustClientService, LLMService llmService, VerbRegistry verbRegistry, + public AgentVerbs( @Value("${agent.ai.config}") String agentConfigFile, + @Value("${agent.ai.context.db.id:none}") String agentDatabaseContext, + ZeroTrustClientService zeroTrustClientService, LLMService llmService, VerbRegistry verbRegistry, AgentClientService agentService ) throws JsonProcessingException { + super(agentConfigFile, agentDatabaseContext, agentService); this.zeroTrustClientService = zeroTrustClientService; this.llmService = llmService; this.verbRegistry = verbRegistry; - this.agentClientService = agentService; log.info("Loading agent config from {}", agentConfigFile); } @@ -101,11 +104,8 @@ public AgentVerbs(ZeroTrustClientService zeroTrustClientService, LLMService llmS isAiCallable = false, requiresTokenManagement = true) public ArrayNode promptAgent(AgentExecution execution, Map args) throws ZtatException, IOException { - InputStream is = getClass().getClassLoader().getResourceAsStream(agentConfigFile); - if (is == null) { - throw new RuntimeException(agentConfigFile + " not found on classpath"); - } - AgentConfig config = new ObjectMapper(new YAMLFactory()).readValue(is, AgentConfig.class); + + AgentConfig config = getAgentConfig(execution); log.info("Agent config loaded: {}", config); PromptBuilder promptBuilder = new PromptBuilder(verbRegistry, config); @@ -114,7 +114,7 @@ public ArrayNode promptAgent(AgentExecution execution, Map args) messages.add(Message.builder().role("system").content(prompt).build()); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o").messages(messages).build(); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); execution.addMessages( messages ); Response response = JsonUtil.MAPPER.readValue(resp, Response.class); @@ -221,7 +221,7 @@ public String justifyAgent(AgentExecution execution, ZtatRequestDTO ztatRequest, messages.add(Message.builder().role("system").content("please respond in the following json " + "format: " + respondZtat).build()); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o").messages(messages).build(); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); execution.addMessages( messages ); var resp = llmService.askQuestion(execution, chatRequest); Response response = JsonUtil.MAPPER.readValue(resp, Response.class); @@ -270,7 +270,8 @@ public String justifyAgent(AgentExecution execution, ZtatRequestDTO ztatRequest, * @throws ZtatException If there is an error during the operation. * @throws IOException If there is an error reading the configuration file. */ - @Verb(name = "assess_data", returnType = ArrayNode.class, description = "Accepts api server data based on the " + + @Verb(name = "assess_api_data", returnType = ArrayNode.class, description = "Accepts api server data based on the" + + " " + "context and seeks" + " to perform the assessment by prompting the LLM. Can be used to assess data or request information from " + "users and/or agents, but not for assessing ztat requests.", @@ -278,25 +279,60 @@ public String justifyAgent(AgentExecution execution, ZtatRequestDTO ztatRequest, inputInterpreter = ObjectListInterpreter.class, requiresTokenManagement = true) public List assessData(AgentExecution execution, List objectList) throws ZtatException, IOException { - InputStream is = getClass().getClassLoader().getResourceAsStream(agentConfigFile); - if (is == null) { - throw new RuntimeException("assessor-config.yaml not found on classpath"); - } - AgentConfig config = new ObjectMapper(new YAMLFactory()).readValue(is, AgentConfig.class); + AgentConfig config = getAgentConfig(execution); log.info("Agent config loaded: {}", config); List responses = new ArrayList<>(); log.info("Object list is {}", objectList); - for (var obj : objectList) { + if (null != objectList) { + for (var obj : objectList) { + List messages = new ArrayList<>(); + var context = config.getContext(); + + var userMessage =Message.builder().role("user").content(obj.toString()).build(); + execution.addMessages(userMessage); + messages.add(userMessage); + + + + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + + var resp = llmService.askQuestion(execution, chatRequest); + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); + log.info("Response is {}", resp); + for (Response.Choice choice : response.getChoices()) { + var content = choice.getMessage().getContent(); + if (content.startsWith("```json")) { + content = content.substring(7, content.length() - 3); + } + + + responses.add(AssessedTerminal.builder().assessment(JsonUtil.MAPPER.readValue( + content, + Assessment.class + )).messages(messages).build()); + log.info("content is {}", content); + } + log.info("Object is {}", obj); + } + }else { List messages = new ArrayList<>(); var context = config.getContext(); - messages.add(Message.builder().role("user").content(obj.toString()).build()); - messages.add(Message.builder().role("system").content(context).build()); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o").messages(messages).build(); - execution.addMessages( messages ); + messages.addAll( execution.getMessages()); + + var assistantMessage = + Message.builder().role("assistant").content("Assess the previous data, but respond with the " + + "following format { \"assessment\"{ sessionId, risk, description} } where description is your " + + "assessment, sessionId is a random UUID or a previously found sessionId, and risk is a measure of" + + " low, medium, and high of the previous data" ).build(); + execution.addMessages(assistantMessage); + + + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + execution.addMessages(messages); var resp = llmService.askQuestion(execution, chatRequest); Response response = JsonUtil.MAPPER.readValue(resp, Response.class); log.info("Response is {}", resp); @@ -307,11 +343,12 @@ public List assessData(AgentExecution execution, List objec } - responses.add(AssessedTerminal.builder().assessment(JsonUtil.MAPPER.readValue(content, - Assessment.class)).messages(messages).build()); + responses.add(AssessedTerminal.builder().assessment(JsonUtil.MAPPER.readValue( + content, + Assessment.class + )).messages(messages).build()); log.info("content is {}", content); } - log.info("Object is {}", obj); } return responses; } @@ -379,11 +416,6 @@ public List analyzeAtatRequests(AgentExecution execution, List analyzeAtatRequests(AgentExecution execution, List responses = new ArrayList<>(); log.info("Size of requests {}", requests.size()); @@ -414,7 +446,7 @@ public List analyzeAtatRequests(AgentExecution execution, List analyzeAtatRequests(AgentExecution execution, List analyzeAtatRequests(AgentExecution execution, List messages = new ArrayList<>(); + messages.add( listedEndpoints); + execution.addMessages(Message.builder().role("user").content(queryInput.toString()).build()); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + var resp = llmService.askQuestion(execution, chatRequest); + + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); + log.info("Response is {}", resp); + for (Response.Choice choice : response.getChoices()) { + var content = choice.getMessage().getContent(); + if (content.startsWith("```json")) { + content = content.substring(7, content.length() - 3); + } else if (content.startsWith("```")) { + content = content.substring(3, content.length() - 3); + } + log.info("content is {}", content); + if (null != content && !content.isEmpty()) { + try { + + ObjectNode newResponse = JsonUtil.MAPPER.createObjectNode(); + JsonNode node = JsonUtil.MAPPER.readTree(content); + + if (node.isArray()) { + ArrayNode arrayNode = (ArrayNode) node; + newResponse.put("endpoints", arrayNode); + return newResponse; + } else { + log.warn("Expected JSON array but got: {}", node.getNodeType()); + } + + return newResponse; + }catch (JsonParseException e) { + log.error("Failed to parse terminal response: {}", e.getMessage()); + throw e; + } + } + } + return contextNode; + } } \ No newline at end of file diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AtplVerbs.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AtplVerbs.java new file mode 100644 index 00000000..05b94469 --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AtplVerbs.java @@ -0,0 +1,104 @@ +package io.sentrius.agent.analysis.agents.verbs; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import com.fasterxml.jackson.databind.node.ArrayNode; +import io.sentrius.agent.analysis.agents.interpreters.StringToAtplInterpreter; +import io.sentrius.sso.core.dto.ztat.AtatRequest; +import io.sentrius.sso.core.dto.ztat.TokenDTO; +import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.core.model.verbs.DefaultInterpreter; +import io.sentrius.sso.core.model.verbs.Verb; +import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.LLMService; +import io.sentrius.sso.core.services.agents.ZeroTrustClientService; +import io.sentrius.sso.core.trust.ATPLPolicy; +import io.sentrius.sso.core.utils.JsonUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +/** + * The `TerminalVerbs` class provides methods to interact with terminal-related operations. + * It includes functionality to list open terminals and fetch terminal logs. + */ +@Slf4j +@Service +public class AtplVerbs extends VerbBase { + + final ZeroTrustClientService zeroTrustClientService; + final LLMService llmService; + final AgentVerbs agentVerbs; + + /** + * Constructs a `TerminalVerbs` instance with the required services. + * + * @param zeroTrustClientService The service for interacting with Zero Trust APIs. + * @param llmService The service for interacting with the LLM (Large Language Model). + */ + public AtplVerbs(@Value("${agent.ai.config}") String agentConfigFile, + @Value("${agent.ai.context.db.id:none}") String agentDatabaseContext, + ZeroTrustClientService zeroTrustClientService, LLMService llmService, AgentVerbs agentVerbs, + AgentClientService agentClientService) { + super(agentConfigFile, agentDatabaseContext, agentClientService); + this.zeroTrustClientService = zeroTrustClientService; + this.llmService = llmService; + this.agentVerbs = agentVerbs; + } + + @Verb(name = "qry_policy_id", description = "Queries by policyId.", + inputInterpreter = StringToAtplInterpreter.class, + outputInterpreter = DefaultInterpreter.class, requiresTokenManagement = true) + public ArrayNode queryPolicyById(TokenDTO token, String policyId) throws ZtatException { + try { + + log.info("policy is : {}", policyId); + String response = zeroTrustClientService.callGetOnApi(token, "/api/v1/policies/" + policyId); + if (response == null) { + throw new RuntimeException("Failed to retrieve terminal list"); + } + log.info("Terminal list response: {}", response); + return (ArrayNode) JsonUtil.MAPPER.readTree(response); + } catch (Exception e) { + throw new RuntimeException("Failed to retrieve terminal list", e); + } + } + + @Verb(name = "get_atpl_schema", description = "Gets Schema. No argument required. Returns JSON Schema.", + inputInterpreter = DefaultInterpreter.class, + outputInterpreter = DefaultInterpreter.class, requiresTokenManagement = true) + public String getAtplSchema(TokenDTO token, Map args) throws ZtatException, IOException { + InputStream schema = getClass().getClassLoader().getResourceAsStream("atpl-schema.json"); + if (schema == null) { + throw new RuntimeException("atpl-schema.json not found on classpath"); + + } + return new String(schema.readAllBytes()); + } + + /** + * Retrieves a list of currently open terminals. + * + * @return An `ArrayNode` containing the list of open terminals. + * @throws io.sentrius.sso.core.exceptions.ZtatException If there is an error during the operation. + */ + @Verb(name = "save_policy", description = "Saves an ATPL policy. Accepts ATPL policy in JSON format.", + inputInterpreter = StringToAtplInterpreter.class, + outputInterpreter = DefaultInterpreter.class, requiresTokenManagement = true) + public ArrayNode savePolicy(TokenDTO token, ATPLPolicy policy) throws ZtatException { + try { + + log.info("policy is : {}", policy); + String response = zeroTrustClientService.callPostOnApi("/api/v1/policies", policy); + if (response == null) { + throw new RuntimeException("Failed to retrieve terminal list"); + } + log.info("Terminal list response: {}", response); + return (ArrayNode) JsonUtil.MAPPER.readTree(response); + } catch (Exception e) { + throw new RuntimeException("Failed to retrieve terminal list", e); + } + } + +} \ No newline at end of file diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java index 369515d4..f0c01d2d 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java @@ -7,9 +7,12 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.ListIterator; +import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import io.sentrius.agent.analysis.agents.agents.AgentConfig; import io.sentrius.agent.analysis.agents.agents.AgentVerb; @@ -17,10 +20,13 @@ import io.sentrius.agent.analysis.agents.agents.VerbRegistry; import io.sentrius.agent.analysis.model.TerminalResponse; import io.sentrius.agent.analysis.model.WebSocky; +import io.sentrius.sso.core.dto.agents.AgentContextDTO; +import io.sentrius.sso.core.dto.capabilities.EndpointDescriptor; import io.sentrius.sso.core.dto.ztat.AgentExecution; import io.sentrius.sso.core.exceptions.ZtatException; import io.sentrius.sso.core.model.verbs.Verb; import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.AgentExecutionService; import io.sentrius.sso.core.services.agents.LLMService; import io.sentrius.sso.core.services.agents.ZeroTrustClientService; import io.sentrius.sso.core.utils.JsonUtil; @@ -35,17 +41,30 @@ @Service @Slf4j -@RequiredArgsConstructor -public class ChatVerbs { - @Value("${agent.ai.config}") - private String agentConfigFile; +public class ChatVerbs extends VerbBase{ + + private final AgentExecutionService agentExecutionService; final ZeroTrustClientService zeroTrustClientService; final LLMService llmService; final VerbRegistry verbRegistry; final AgentClientService agentClientService; + protected ChatVerbs(@Value("${agent.ai.config}") String agentConfigFile, + @Value("${agent.ai.context.db.id:none}") String agentDatabaseContext, + AgentClientService agentClientService, AgentExecutionService agentExecutionService, + ZeroTrustClientService zeroTrustClientService, LLMService llmService, VerbRegistry verbRegistry, + AgentClientService agentClientService1 + ) { + super(agentConfigFile, agentDatabaseContext, agentClientService); + this.agentExecutionService = agentExecutionService; + this.zeroTrustClientService = zeroTrustClientService; + this.llmService = llmService; + this.verbRegistry = verbRegistry; + this.agentClientService = agentClientService1; + } + /** * Prompts the agent for workload based on the provided arguments. * @@ -69,28 +88,34 @@ public TerminalResponse interpretUserData( throw new RuntimeException("assessor-config.yaml not found on classpath"); } + String terminalResponse = new String(terminalHelperStream.readAllBytes()); - InputStream is = getStream(agentConfigFile); - if (is == null) { - throw new RuntimeException(agentConfigFile + " not found on classpath"); - } - AgentConfig config = new ObjectMapper(new YAMLFactory()).readValue(is, AgentConfig.class); + AgentConfig config = getAgentConfig(execution); log.info("Agent config loaded: {}", config); PromptBuilder promptBuilder = new PromptBuilder(verbRegistry, config); - var prompt = promptBuilder.buildPrompt(); + var prompt = promptBuilder.buildPrompt(false + ); List messages = new ArrayList<>(); + var context = Message.builder().role("system").content(prompt).build(); + messages.add(context); - messages.add(Message.builder().role("system").content(prompt).build()); + messages.add(Message.builder().role("system").content("You have executed verbs for the previous user " + + "messages. Please generate a user response that summarizes the last message.").build()); + int size = getMessageSize(context); + + var history = getContextWindow(execution.getMessages(), 1024*96 - (size)); + messages.addAll(history); messages.add(Message.builder().role("system").content("Please ensure your nextOperation abides by the " + "following json format and leave it empty if user's request doesn't require explicit use of system " + "operations" + ". Please summarize prior terminal " + "sessions, using " + "terminal output if needed " + - "for clarity of the next LLM request and for the user: " + terminalResponse).build()); + "for clarity of the next LLM request and for the user. Ensure your all future responses meets this " + + "json format (TerminalResponse format): " + terminalResponse).build()); messages.add(Message.builder().role("user").content(userMessage.getContent()).build()); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o").messages(messages).build(); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); execution.addMessages( messages ); Response response = JsonUtil.MAPPER.readValue(resp, Response.class); @@ -104,60 +129,61 @@ public TerminalResponse interpretUserData( } log.info("content is {}", content); if (null != content && !content.isEmpty()) { - var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue(content, - TerminalResponse.class); - return newResponse; + try { + var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( + content, + TerminalResponse.class + ); + return newResponse; + }catch (JsonParseException e) { + log.error("Failed to parse terminal response: {}", e.getMessage()); + return TerminalResponse.builder().responseForUser(content).build(); + } } } } else { InputStream terminalHelperStream = getClass().getClassLoader().getResourceAsStream("terminal-helper.json"); if (terminalHelperStream == null) { - - - - - - - - - - - - - - - - - - - - - throw new RuntimeException("assessor-config.yaml not found on classpath"); } String terminalResponse = new String(terminalHelperStream.readAllBytes()); - InputStream is = getStream(agentConfigFile); - if (is == null) { - throw new RuntimeException(agentConfigFile + " not found on classpath"); - } - AgentConfig config = new ObjectMapper(new YAMLFactory()).readValue(is, AgentConfig.class); + + AgentConfig config = getAgentConfig(execution); log.info("Agent config loaded: {}", config); PromptBuilder promptBuilder = new PromptBuilder(verbRegistry, config); - var prompt = promptBuilder.buildPrompt(); + var prompt = promptBuilder.buildPrompt(false); List messages = new ArrayList<>(); + //var context = Message.builder().role("system").content(prompt).build(); + //messages.add(context); + + /* + var listedEndpoints = Message.builder().role("system").content("These are a list of available endpoints, " + + "description," + + " their " + + "name:" + endpointArray).build(); + messages.add(listedEndpoints); + messages.add(Message.builder().role("system").content("You have executed verbs for the previous user " + + "messages. Please generate a user response that summarizes the last message.").build()); + int size = getMessageSize(context) + getMessageSize(listedEndpoints); - messages.add(Message.builder().role("system").content(prompt).build()); - messages.add(Message.builder().role("system").content("Please ensure your nextOperation abide by the " + - "following json format. Please summarize prior terminal sessions, using terminal output if needed " + - "for clarity of the next LLM request and for the user: " + terminalResponse).build()); - messages.add(Message.builder().role("assistant").content("prior response: " + lastMessage.getTerminalSummaryForLLM()).build()); + var history = getContextWindow(execution.getMessages(), 1024*96 - (size)); + messages.addAll(history); + + */ + var history = getContextWindow(execution.getMessages(), 1024*96 ); + messages.addAll(history); + +// messages.add(Message.builder().role("assistant").content("prior response: " + lastMessage +// .getTerminalSummaryForLLM()).build()); + + execution.addMessages( userMessage ); messages.add(Message.builder().role("user").content(userMessage.getContent()).build()); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o").messages(messages).build(); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); - execution.addMessages( messages ); + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); log.info("Response is {}", resp); for (Response.Choice choice : response.getChoices()) { @@ -169,9 +195,19 @@ public TerminalResponse interpretUserData( } log.info("content is {}", content); if (null != content && !content.isEmpty()) { - var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue(content, - TerminalResponse.class); - return newResponse; + + execution.addMessages( choice.getMessage() ); + try { + var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( + content, + TerminalResponse.class + ); + return newResponse; + }catch (JsonParseException e) { + log.error("Failed to parse terminal response: {}", e.getMessage()); + return TerminalResponse.builder().responseForUser(content).terminalSummaryForLLM(lastMessage.getTerminalSummaryForLLM()).build(); + } + } } } @@ -179,6 +215,56 @@ public TerminalResponse interpretUserData( return null; } + public List getContextWindow(List allMessages, int maxContextSize) { + List systemMessages = new ArrayList<>(); + List window = new ArrayList<>(); + int totalSize = 0; + + // First: collect system messages (or other required ones) + for (Message msg : allMessages) { + if ("system".equals(msg.role)) { + systemMessages.add(msg); + totalSize += getMessageSize(msg); + } + } + + // If system messages already exceed max context, return only those + if (totalSize >= maxContextSize) { + return systemMessages; + } + + int remainingSize = maxContextSize - totalSize; + + // Then: collect non-system messages from the end, up to remainingSize + ListIterator iter = allMessages.listIterator(allMessages.size()); + while (iter.hasPrevious()) { + Message msg = iter.previous(); + + if ("system".equals(msg.role)) continue; // already added + + int messageSize = getMessageSize(msg); + if (messageSize > remainingSize) break; + + window.add(0, msg); // prepend + remainingSize -= messageSize; + } + + // Combine system + selected recent messages + List result = new ArrayList<>(); + result.addAll(systemMessages); + result.addAll(window); + + return result; + } + + + private int getMessageSize(Message msg) { + int size = 0; + if (msg.role != null) size += msg.role.length(); + if (msg.content != null) size += msg.content.length(); + if (msg.refusal != null) size += msg.refusal.length(); + return size; + } public TerminalResponse interpret_plan_response( AgentExecution execution, @NonNull WebSocky socketConnection, @@ -196,30 +282,44 @@ public TerminalResponse interpret_plan_response( } String terminalResponse = new String(terminalHelperStream.readAllBytes()); - InputStream is = getStream(agentConfigFile); - if (is == null) { - throw new RuntimeException(agentConfigFile + " not found on classpath"); - } - AgentConfig config = new ObjectMapper(new YAMLFactory()).readValue(is, AgentConfig.class); + + AgentConfig config = getAgentConfig(execution); log.info("Agent config loaded: {}", config); PromptBuilder promptBuilder = new PromptBuilder(verbRegistry, config); - var prompt = promptBuilder.buildPrompt(); + var prompt = promptBuilder.buildPrompt(false); List messages = new ArrayList<>(); - messages.add(Message.builder().role("system").content(prompt).build()); - messages.add(Message.builder().role("system").content("You have executed verbs for the previous user " + - "messages. Please generate a user response that summarizes the last message.").build()); - if (null != agentVerb ){ - messages.add(Message.builder().role("system").content("You have executed verb: " + agentVerb.getName() + - " with the following description: " + agentVerb.getDescription()).build()); + if (execution.getMessages().isEmpty()) { + log.info("*** Adding Prompt"); + var context = Message.builder().role("system").content(prompt).build(); + messages.add(context); + + + + messages.add(Message.builder().role("system").content("You have executed verbs for the previous user " + + "messages. Please generate a user response that summarizes the last message. Keep all responses in " + + "TerminalResponse format" + + ".").build()); + } + + + if (null != agentVerb) { + messages.add( + Message.builder().role("system").content("You have executed verb: " + agentVerb.getName() + + " with the following description: " + agentVerb.getDescription()).build()); } - messages.add(Message.builder().role("assistant").content("prior response: " + lastMessage.getTerminalSummaryForLLM()).build()); + + + //messages.add(Message.builder().role("assistant").content("prior response: " + lastMessage + // .getTerminalSummaryForLLM()).build()); messages.add(Message.builder().role("assistant").content(planExecutionOutput).build()); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o").messages(messages).build(); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); execution.addMessages( messages ); + var history = getContextWindow(execution.getMessages(), 1024*96 ); + messages.addAll(history); Response response = JsonUtil.MAPPER.readValue(resp, Response.class); log.info("Response is {}", resp); for (Response.Choice choice : response.getChoices()) { @@ -231,23 +331,21 @@ public TerminalResponse interpret_plan_response( } log.info("content is {}", content); if (null != content && !content.isEmpty()) { - var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue(content, - TerminalResponse.class); - return newResponse; + try { + var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( + content, + TerminalResponse.class + ); + return newResponse; + } catch (Exception e){ + return TerminalResponse.builder().responseForUser(content).build(); + } + } } return null; } - private InputStream getStream(String requestedPath) throws IOException { - Path path = Paths.get(requestedPath); // 🔁 Replace with your actual path - if (!Files.exists(path)) { - throw new RuntimeException("File not found at path: " + path.toAbsolutePath()); - } - - return Files.newInputStream(path); - - } } diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ExampleFactory.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ExampleFactory.java new file mode 100644 index 00000000..418ec15c --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ExampleFactory.java @@ -0,0 +1,41 @@ +package io.sentrius.agent.analysis.agents.verbs; + +import java.util.List; +import java.util.Map; +import io.sentrius.sso.core.dto.agents.AgentContextDTO; +import io.sentrius.sso.core.trust.ATPLPolicy; +import io.sentrius.sso.core.trust.Capability; +import io.sentrius.sso.core.trust.CapabilitySet; + +public class ExampleFactory { + public static Object createExample(String paramName, Class type) { + if (type.equals(Map.class)) { + return Map.of("key", "value"); + } + if (type.equals(List.class)) { + return List.of(Map.of("key", "value")); + } + if (type.equals(String.class)) { + return "{ \"" + paramName + "\" : Example String value\" }"; + } + if (type.equals(AgentContextDTO.class)) { + return AgentContextDTO.builder().context("This is the context for the agent").description("Agent " + + "description").build(); + } + if (type.equals(ATPLPolicy.class)) { + return ATPLPolicy.builder() + .policyId("policy-001") + .description("Example policy") + .version("v0") + .capabilities( + CapabilitySet.builder().primitives( + List.of( + Capability.builder().description("description").endpoints(List.of( + "endpoint1", "endpoint2")).build() + )).build()) + .build(); + } + // fallback + return Map.of("field", "value"); + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java new file mode 100644 index 00000000..77dda611 --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java @@ -0,0 +1,67 @@ +package io.sentrius.agent.analysis.agents.verbs; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import io.sentrius.agent.analysis.agents.agents.AgentConfig; +import io.sentrius.sso.core.dto.agents.AgentContextDTO; +import io.sentrius.sso.core.dto.ztat.AgentExecution; +import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.core.services.agents.AgentClientService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; + +@Slf4j +public abstract class VerbBase { + @Value("${agent.ai.config}") + protected String agentConfigFile; + + @Value("${agent.ai.context.db.id:none}") + protected String agentDatabaseContext; + + + protected final AgentClientService agentClientService; + + protected VerbBase(@Value("${agent.ai.config}") String agentConfigFile, + @Value("${agent.ai.context.db.id:none}") String agentDatabaseContext, + AgentClientService agentClientService) { + this.agentClientService = agentClientService; + this.agentConfigFile = agentConfigFile; + this.agentDatabaseContext = agentDatabaseContext; + } + + protected AgentConfig getAgentConfig(AgentExecution execution) throws IOException, ZtatException { + AgentConfig config = null; + if (agentDatabaseContext != null && !agentDatabaseContext.equals("none")) { + AgentContextDTO agentContext = agentClientService.getAgentContext(execution, + agentDatabaseContext); + config = AgentConfig.builder().description(agentContext.getDescription()) + .context(agentContext.getContext()).build(); + log.info("Agent context loaded: {}", agentContext); + }else { + + InputStream is = getStream(agentConfigFile); + if (is == null) { + throw new RuntimeException(agentConfigFile + " not found on classpath"); + } + + config = new ObjectMapper(new YAMLFactory()).readValue(is, AgentConfig.class); + } + return config; + } + + private InputStream getStream(String requestedPath) throws IOException { + Path path = Paths.get(requestedPath); // 🔁 Replace with your actual path + + if (!Files.exists(path)) { + throw new RuntimeException("File not found at path: " + path.toAbsolutePath()); + } + + return Files.newInputStream(path); + + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java index 6ff71c11..ad28049f 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java @@ -184,8 +184,10 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message) websocketCommunication, userMessage); log.info("Response: {}", response); var newMessage = Session.ChatMessage.newBuilder() - .setMessage(String.format("{\"type\":\"user-message\",\"message\":\"%s\"}", - response.getResponseForUser())) + .setMessage(response.getResponseForUser()/*String.format("{\"type\":\"user-message\"," + + "\"message\":\"%s\"}", + response.getResponseForUser())*/ + ) .setSender("agent") .setChatGroupId("") .setSessionId(Long.parseLong(websocketCommunication.getSessionId())) @@ -203,40 +205,55 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message) verbRegistry.isVerbRegistered(response.getNextOperation())) { try { + + TerminalResponse nextResponse = null; + var lastVerbResponse = - websocketCommunication.getVerbResponses().stream().reduce((prev, next) -> next) + websocketCommunication.getVerbResponses().stream() + .reduce((prev, next) -> next) .orElse(null); - var executionResponse = verbRegistry.execute( - chatAgent.getAgentExecution(), - lastVerbResponse, - response.getNextOperation(), Maps.newHashMap() - ); - - var nextResponse = chatVerbs.interpret_plan_response( - chatAgent.getAgentExecution(), websocketCommunication, - verbRegistry.getVerbs().get(response.getNextOperation()), - executionResponse.getResponse().toString() - ); - - websocky.get().getMessages().add(nextResponse); - - websocketCommunication.getVerbResponses().add(executionResponse); - - var newNextMessage = Session.ChatMessage.newBuilder() - .setMessage(String.format( - "{\"type\":\"user-message\",\"message\":\"%s\"}", - nextResponse.getResponseForUser() - )) - .setSender("agent") - .setChatGroupId("") - .setSessionId(Long.parseLong(websocketCommunication.getSessionId())) - .setTimestamp(System.currentTimeMillis()) - .build(); - messageBytes = newNextMessage.toByteArray(); - base64Message = Base64.getEncoder().encodeToString(messageBytes); - session.sendMessage(new TextMessage( - base64Message - )); + do { + + var arguments = response.getArguments(); + var executionResponse = verbRegistry.execute( + chatAgent.getAgentExecution(), + lastVerbResponse, + response.getNextOperation(), arguments + ); + +// chatAgent.getAgentExecution().addMessages(Message.builder().role("System") +// .content("System executed operation: " + response.getNextOperation()).build()); + + nextResponse = chatVerbs.interpret_plan_response( + chatAgent.getAgentExecution(), websocketCommunication, + verbRegistry.getVerbs().get(response.getNextOperation()), + executionResponse.getResponse().toString() + ); + + websocky.get().getMessages().add(nextResponse); + + websocketCommunication.getVerbResponses().add(executionResponse); + + var newNextMessage = Session.ChatMessage.newBuilder() + .setMessage( + nextResponse.getResponseForUser() + ) + .setSender("agent") + .setChatGroupId("") + .setSessionId(Long.parseLong(websocketCommunication.getSessionId())) + .setTimestamp(System.currentTimeMillis()) + .build(); + messageBytes = newNextMessage.toByteArray(); + base64Message = Base64.getEncoder().encodeToString(messageBytes); + session.sendMessage(new TextMessage( + base64Message + )); + log.info("Next response: {}", nextResponse.getResponseForUser()); + log.info("Next getNextOperation: {}", nextResponse.getNextOperation()); + log.info("Next getArguments: {}", nextResponse.getArguments()); + lastVerbResponse = executionResponse; + response = nextResponse; + }while (nextResponse.getNextOperation() != null && !nextResponse.getNextOperation().isEmpty()); }catch (Exception e){ e.printStackTrace(); log.error("Error executing next operation: {}", e.getMessage()); diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java index 02655a82..5bcacf37 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java @@ -2,6 +2,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import io.sentrius.sso.core.dto.HostSystemDTO; import io.sentrius.sso.genai.Message; import lombok.AllArgsConstructor; @@ -23,4 +24,6 @@ public class TerminalResponse { String nextOperation; String terminalSummaryForLLM; String responseForUser; + @Builder.Default + public Map arguments = Map.of(); } diff --git a/ai-agent/src/main/java/io/sentrius/agent/config/SecurityConfig.java b/ai-agent/src/main/java/io/sentrius/agent/config/SecurityConfig.java index 398de64d..ff352101 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/config/SecurityConfig.java +++ b/ai-agent/src/main/java/io/sentrius/agent/config/SecurityConfig.java @@ -66,9 +66,7 @@ public JwtAuthenticationConverter jwtAuthenticationConverterForKeycloak() { log.info("**** Initializing JwtAuthenticationConverter"); converter.setJwtGrantedAuthoritiesConverter(jwt -> { - log.info("**** Jwt Authentication Converter invoked"); Collection authorities = new JwtGrantedAuthoritiesConverter().convert(jwt); - log.info("JWT Claims: {}", jwt.getClaims()); String userId = jwt.getClaimAsString("sub"); String username = jwt.getClaimAsString("preferred_username"); diff --git a/ai-agent/src/main/resources/atpl-schema.json b/ai-agent/src/main/resources/atpl-schema.json new file mode 100644 index 00000000..14839924 --- /dev/null +++ b/ai-agent/src/main/resources/atpl-schema.json @@ -0,0 +1,199 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://sentrius.io/schemas/atpl.schema.json", + "title": "Agent Trust Policy Language (ATPL)", + "type": "object", + "required": ["version", "policy_id", "capabilities"], + "properties": { + "version": { + "type": "string", + "pattern": "^v[0-9]+$" + }, + "policy_id": { + "type": "string" + }, + "description": { + "type": "string" + }, + "match": { + "type": "object", + "properties": { + "agent_tags": { + "type": "array", + "items": { "type": "string" } + } + } + }, + "identity": { + "type": "object", + "properties": { + "issuer": { "type": "string" }, + "subject_prefix": { "type": "string" }, + "mfa_required": { "type": "boolean" }, + "certificate_authority": { "type": "string" } + } + }, + "provenance": { + "type": "object", + "properties": { + "source": { "type": "string" }, + "signature_required": { "type": "boolean" }, + "approved_committers": { + "type": "array", + "items": { "type": "string" } + } + } + }, + "runtime": { + "type": "object", + "properties": { + "allow_drift": { "type": "boolean" }, + "enclave_required": { "type": "boolean" }, + "attestation_type": { "type": "string" }, + "verified_at_boot": { "type": "boolean" } + } + }, + "behavior": { + "type": "object", + "properties": { + "minimum_positive_runs": { "type": "integer" }, + "max_incidents": { "type": "integer" }, + "incident_types": { + "type": "object", + "properties": { + "denylist": { + "type": "array", + "items": { "type": "string" } + } + } + } + } + }, + "trust_score": { + "type": "object", + "properties": { + "minimum": { "type": "integer" }, + "weightings": { + "type": "object", + "properties": { + "identity": { "type": "number" }, + "provenance": { "type": "number" }, + "runtime": { "type": "number" }, + "behavior": { "type": "number" } + } + } + } + }, + "capabilities": { + "type": "object", + "required": ["primitives", "composed"], + "properties": { + "primitives": { + "type": "array", + "items": { + "type": "object", + "required": ["id", "description"], + "properties": { + "id": { "type": "string" }, + "description": { "type": "string" }, + "endpoints": { + "type": "array", + "items": { "type": "string" } + }, + "commands": { + "type": "array", + "items": { "type": "string" } + }, + "subcommands": { + "type": "array", + "items": { "type": "string" } + }, + "activities": { + "type": "array", + "items": { "type": "string" } + }, + "tags": { + "type": "array", + "items": { "type": "string" } + } + } + } + }, + "composed": { + "type": "array", + "items": { + "type": "object", + "required": ["id", "includes"], + "properties": { + "id": { "type": "string" }, + "includes": { + "type": "array", + "items": { "type": "string" } + }, + "tags": { + "type": "array", + "items": { "type": "string" } + }, + "endpoints": { + "type": "array", + "items": { "type": "string" } + }, + "commands": { + "type": "array", + "items": { "type": "string" } + }, + "subcommands": { + "type": "array", + "items": { "type": "string" } + }, + "activities": { + "type": "array", + "items": { "type": "string" } + } + } + } + } + } + } + , + "actions": { + "type": "object", + "properties": { + "on_failure": { + "type": "string", + "enum": ["deny", "log", "alert"] + }, + "on_success": { + "type": "string", + "enum": ["allow", "log"] + }, + "on_marginal": { + "oneOf": [ + {"type": "string", "enum": ["require_ztat", "log"]}, + { + "type": "object", + "properties": { + "action": {"type": "string", "enum": ["require_ztat"]}, + "ztat_provider": {"type": "string"} + }, + "required": ["action"] + } + ] + } + } + }, + "ztat": { + "type": "object", + "properties": { + "provider": {"type": "string"}, + "ttl": {"type": "string"}, + "approved_issuers": { + "type": "array", + "items": {"type": "string"} + }, + "key_binding": {"type": "string"}, + "approval_required": {"type": "boolean"} + } + } + } +} diff --git a/ai-agent/src/main/resources/chat-helper.json b/ai-agent/src/main/resources/chat-helper.json index 7bc969b9..bffb7174 100644 --- a/ai-agent/src/main/resources/chat-helper.json +++ b/ai-agent/src/main/resources/chat-helper.json @@ -1,6 +1,10 @@ { "previousOperation": "", "nextOperation": "", + "arguments": { + "argumentname": "", + "argumentname2": "" + }, "terminalSummaryForLLM": "", "responseForUser": "" } \ No newline at end of file diff --git a/ai-agent/src/main/resources/terminal-helper.json b/ai-agent/src/main/resources/terminal-helper.json index 7bc969b9..02be454b 100644 --- a/ai-agent/src/main/resources/terminal-helper.json +++ b/ai-agent/src/main/resources/terminal-helper.json @@ -1,6 +1,10 @@ { "previousOperation": "", "nextOperation": "", + "arguments": { + "argumentname": "", + "argumentname2": "" + }, "terminalSummaryForLLM": "", - "responseForUser": "" + "responseForUser": "" } \ No newline at end of file diff --git a/ai-agent/src/test/java/io/sentrius/sentrius/analysis/agents/agents/PromptBuilderTest.java b/ai-agent/src/test/java/io/sentrius/sentrius/analysis/agents/agents/PromptBuilderTest.java index 2fc5c663..341a6895 100644 --- a/ai-agent/src/test/java/io/sentrius/sentrius/analysis/agents/agents/PromptBuilderTest.java +++ b/ai-agent/src/test/java/io/sentrius/sentrius/analysis/agents/agents/PromptBuilderTest.java @@ -51,7 +51,7 @@ void buildPromptIncludesAvailableVerbs() { String result = promptBuilder.buildPrompt(); - assertTrue(result.contains("Available Verbs:")); + assertTrue(result.contains("Verb operations:")); assertTrue(result.contains("- verbName (")); assertTrue(result.contains("Description of the verb")); } @@ -72,7 +72,7 @@ void buildPromptHandlesNoAvailableVerbs() { String result = promptBuilder.buildPrompt(); - assertTrue(result.contains("Available Verbs:")); + assertTrue(result.contains("Verb operations:")); assertFalse(result.contains("- ")); } } \ No newline at end of file diff --git a/analytics/src/main/java/io/sentrius/agent/config/SecurityConfig.java b/analytics/src/main/java/io/sentrius/agent/config/SecurityConfig.java index f42d2155..c6ef7f12 100644 --- a/analytics/src/main/java/io/sentrius/agent/config/SecurityConfig.java +++ b/analytics/src/main/java/io/sentrius/agent/config/SecurityConfig.java @@ -73,9 +73,9 @@ public JwtAuthenticationConverter jwtAuthenticationConverterForKeycloak() { log.info("**** Initializing JwtAuthenticationConverter"); converter.setJwtGrantedAuthoritiesConverter(jwt -> { - log.info("**** Jwt Authentication Converter invoked"); + Collection authorities = new JwtGrantedAuthoritiesConverter().convert(jwt); - log.info("JWT Claims: {}", jwt.getClaims()); + String userId = jwt.getClaimAsString("sub"); String username = jwt.getClaimAsString("preferred_username"); diff --git a/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java b/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java index 843482f0..bef57d79 100644 --- a/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java +++ b/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java @@ -78,9 +78,7 @@ public JwtAuthenticationConverter jwtAuthenticationConverterForKeycloak() { log.info("**** Initializing JwtAuthenticationConverter"); converter.setJwtGrantedAuthoritiesConverter(jwt -> { - log.info("**** Jwt Authentication Converter invoked"); Collection authorities = new JwtGrantedAuthoritiesConverter().convert(jwt); - log.info("JWT Claims: {}", jwt.getClaims()); String userId = jwt.getClaimAsString("sub"); String username = jwt.getClaimAsString("preferred_username"); diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/ATPLChatController.java b/api/src/main/java/io/sentrius/sso/controllers/api/ATPLChatController.java new file mode 100644 index 00000000..45524e57 --- /dev/null +++ b/api/src/main/java/io/sentrius/sso/controllers/api/ATPLChatController.java @@ -0,0 +1,206 @@ +package io.sentrius.sso.controllers.api; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.sentrius.sso.core.annotations.LimitAccess; +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.controllers.BaseController; +import io.sentrius.sso.core.model.security.enums.ApplicationAccessEnum; +import io.sentrius.sso.core.services.ATPLPolicyService; +import io.sentrius.sso.core.services.ErrorOutputService; +import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.trust.ATPLPolicy; +import io.sentrius.sso.services.ATPLChatService; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; + +@Slf4j +@RestController +@RequestMapping("/api/v1/atpl/chat") +public class ATPLChatController extends BaseController { + + private final ATPLChatService atplChatService; + private final ATPLPolicyService atplPolicyService; + private final ObjectMapper objectMapper; + + // Store conversation sessions (in production, use Redis or database) + private final Map> chatSessions = new HashMap<>(); + + public ATPLChatController( + UserService userService, + SystemOptions systemOptions, + ErrorOutputService errorOutputService, + ATPLChatService atplChatService, + ATPLPolicyService atplPolicyService, + ObjectMapper objectMapper + ) { + super(userService, systemOptions, errorOutputService); + this.atplChatService = atplChatService; + this.atplPolicyService = atplPolicyService; + this.objectMapper = objectMapper; + } + + @PostMapping("/message") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity processMessage( + @RequestBody Map request, + HttpServletRequest httpRequest, + HttpServletResponse httpResponse + ) { + try { + String sessionId = (String) request.get("sessionId"); + String message = (String) request.get("message"); + + if (sessionId == null || message == null) { + return ResponseEntity.badRequest().body(createErrorResponse("Session ID and message are required")); + } + + // Get or create session context + Map sessionContext = chatSessions.computeIfAbsent(sessionId, k -> new HashMap<>()); + + // Process the message with the ATPL chat service + String response = atplChatService.processATPLChatMessage(message, sessionContext); + + // Update session context with conversation history + updateSessionContext(sessionContext, message, response); + + ObjectNode result = objectMapper.createObjectNode(); + result.put("response", response); + result.put("sessionId", sessionId); + + return ResponseEntity.ok(result); + + } catch (Exception e) { + log.error("Error processing ATPL chat message: {}", e.getMessage(), e); + return ResponseEntity.internalServerError().body(createErrorResponse("Failed to process message")); + } + } + + @PostMapping("/generate-policy") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity generatePolicy( + @RequestBody Map request, + HttpServletRequest httpRequest, + HttpServletResponse httpResponse + ) { + try { + String sessionId = (String) request.get("sessionId"); + Map sessionContext = chatSessions.get(sessionId); + + if (sessionContext == null) { + return ResponseEntity.badRequest().body(createErrorResponse("Session not found")); + } + + ObjectNode policyNode = atplChatService.generateATPLPolicy(sessionContext); + + ObjectNode result = objectMapper.createObjectNode(); + result.set("policy", policyNode); + result.put("sessionId", sessionId); + + return ResponseEntity.ok(result); + + } catch (Exception e) { + log.error("Error generating ATPL policy: {}", e.getMessage(), e); + return ResponseEntity.internalServerError().body(createErrorResponse("Failed to generate policy")); + } + } + + @GetMapping("/suggestions") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity getSuggestions( + @RequestParam String description, + HttpServletRequest httpRequest, + HttpServletResponse httpResponse + ) { + try { + List suggestions = atplChatService.suggestCapabilities(description); + + ObjectNode result = objectMapper.createObjectNode(); + ArrayNode suggestionsArray = objectMapper.createArrayNode(); + suggestions.forEach(suggestionsArray::add); + result.set("suggestions", suggestionsArray); + + return ResponseEntity.ok(result); + + } catch (Exception e) { + log.error("Error getting suggestions: {}", e.getMessage(), e); + return ResponseEntity.internalServerError().body(createErrorResponse("Failed to get suggestions")); + } + } + + @GetMapping("/existing-policies") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity getExistingPolicies( + HttpServletRequest httpRequest, + HttpServletResponse httpResponse + ) { + try { + List policies = atplPolicyService.getAllPolicies(); + + ObjectNode result = objectMapper.createObjectNode(); + ArrayNode policiesArray = objectMapper.createArrayNode(); + + for (ATPLPolicy policy : policies) { + ObjectNode policyNode = objectMapper.createObjectNode(); + policyNode.put("id", policy.getPolicyId()); + policyNode.put("description", policy.getDescription()); + policyNode.put("version", policy.getVersion()); + + if (policy.getCapabilities() != null && policy.getCapabilities().getPrimitives() != null) { + ArrayNode capabilitiesArray = objectMapper.createArrayNode(); + policy.getCapabilities().getPrimitives().forEach(cap -> { + ObjectNode capNode = objectMapper.createObjectNode(); + capNode.put("id", cap.getId()); + capNode.put("description", cap.getDescription()); + capabilitiesArray.add(capNode); + }); + policyNode.set("capabilities", capabilitiesArray); + } + + policiesArray.add(policyNode); + } + + result.set("policies", policiesArray); + return ResponseEntity.ok(result); + + } catch (Exception e) { + log.error("Error getting existing policies: {}", e.getMessage(), e); + return ResponseEntity.internalServerError().body(createErrorResponse("Failed to get existing policies")); + } + } + + private void updateSessionContext(Map sessionContext, String userMessage, String agentResponse) { + // Store conversation history + @SuppressWarnings("unchecked") + List> history = (List>) sessionContext.computeIfAbsent("conversation_history", k -> new java.util.ArrayList<>()); + + Map userEntry = new HashMap<>(); + userEntry.put("type", "user"); + userEntry.put("message", userMessage); + history.add(userEntry); + + Map agentEntry = new HashMap<>(); + agentEntry.put("type", "assistant"); + agentEntry.put("message", agentResponse); + history.add(agentEntry); + + // Keep only last 10 exchanges + if (history.size() > 20) { + history.subList(0, history.size() - 20).clear(); + } + } + + private ObjectNode createErrorResponse(String message) { + ObjectNode error = objectMapper.createObjectNode(); + error.put("error", message); + return error; + } +} \ No newline at end of file diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java index 40a303e3..21551a15 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java @@ -15,11 +15,16 @@ import java.util.concurrent.ExecutionException; import com.fasterxml.jackson.core.JsonProcessingException; import io.sentrius.sso.config.ApiPaths; +import io.sentrius.sso.config.AppConfig; import io.sentrius.sso.core.annotations.LimitAccess; import io.sentrius.sso.core.config.SystemOptions; import io.sentrius.sso.core.controllers.BaseController; import io.sentrius.sso.core.dto.AgentCommunicationDTO; +import io.sentrius.sso.core.dto.AgentDTO; import io.sentrius.sso.core.dto.AgentHeartbeatDTO; +import io.sentrius.sso.core.dto.agents.AgentContextDTO; +import io.sentrius.sso.core.dto.agents.AgentContextRequestDTO; +import io.sentrius.sso.core.exceptions.ZtatException; import io.sentrius.sso.core.model.chat.AgentCommunication; import io.sentrius.sso.core.model.security.enums.IdentityType; import io.sentrius.sso.core.model.security.UserType; @@ -32,6 +37,8 @@ import io.sentrius.sso.core.services.ATPLPolicyService; import io.sentrius.sso.core.services.ErrorOutputService; import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.AgentContextService; import io.sentrius.sso.core.services.agents.AgentService; import io.sentrius.sso.core.services.auditing.AuditService; import io.sentrius.sso.core.services.security.CryptoService; @@ -54,6 +61,7 @@ import org.springframework.format.annotation.DateTimeFormat; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestHeader; @@ -75,7 +83,10 @@ public class AgentApiController extends BaseController { final ZeroTrustRequestService ztrService; final AgentService agentService; final ProvenanceKafkaProducer provenanceKafkaProducer; - private final ZeroTrustRequestService ztatRequestService; + final ZeroTrustRequestService ztatRequestService; + final AgentContextService agentContextService; + final AgentClientService agentClientService; + final AppConfig appConfig; public AgentApiController( UserService userService, @@ -85,7 +96,8 @@ public AgentApiController( CryptoService cryptoService, SessionTrackingService sessionTrackingService, KeycloakService keycloakService, ATPLPolicyService atplPolicyService, ZeroTrustAccessTokenService ztatService, ZeroTrustRequestService ztrService, AgentService agentService, - ProvenanceKafkaProducer provenanceKafkaProducer, ZeroTrustRequestService ztatRequestService + ProvenanceKafkaProducer provenanceKafkaProducer, ZeroTrustRequestService ztatRequestService, + AgentContextService agentContextService, AgentClientService agentClientService, AppConfig appConfig ) { super(userService, systemOptions, errorOutputService); this.auditService = auditService; @@ -98,6 +110,9 @@ public AgentApiController( this.agentService = agentService; this.provenanceKafkaProducer = provenanceKafkaProducer; this.ztatRequestService = ztatRequestService; + this.agentContextService = agentContextService; + this.agentClientService = agentClientService; + this.appConfig = appConfig; } public SessionLog createSession(@RequestParam String username, @RequestParam String ipAddress) { @@ -282,7 +297,7 @@ public ResponseEntity createAgentChatRequest( @GetMapping("/list") @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) - public ResponseEntity listAgents(HttpServletRequest request, HttpServletResponse response){ + public ResponseEntity listAgents(HttpServletRequest request, HttpServletResponse response) throws ZtatException { var operatingUser = getOperatingUser(request, response ); if (null == operatingUser) { log.warn("No operating user found"); @@ -290,7 +305,28 @@ public ResponseEntity listAgents(HttpServletRequest request, HttpServletRespo } log.info("Received list request from user: {} {}", operatingUser.getUsername(), operatingUser); var agents = agentService.getAllAgents(true); - return ResponseEntity.ok(agents); + + List prunedAgentList = agents.stream().filter(agent -> { + try { + if (null == agent.getAgentName() || agent.getAgentName().isEmpty()) { + log.info("Agent {} has no name, removing from list", agent.getAgentId()); + return false; + } + String podResponse = + agentClientService.getAgentPodStatus(appConfig.getSentriusLauncherService(), agent.getAgentName()); + if (podResponse != null && (podResponse.equalsIgnoreCase("running") || podResponse.equalsIgnoreCase("pending"))){ + return true; + } else { + log.info("Agent {} is not running or pending, removing from list. Status is {}", + agent.getAgentName(), podResponse); + } + } catch (ZtatException ignored) { + + } + return false; + } + ).toList(); + return ResponseEntity.ok(prunedAgentList); } @@ -754,4 +790,38 @@ private boolean validateUser(User requestor, User operatingUser, AgentCommunicat return canSend; } + @GetMapping("/context/{contextId}") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity getContext( + HttpServletRequest request, + HttpServletResponse response, + @PathVariable("contextId") String contextId){ + var databaseContext = agentContextService.getContextOrThrow(UUID.fromString(contextId)); + return ResponseEntity.ok(AgentContextDTO.builder() + .id(databaseContext.getId()) + .name(databaseContext.getName()) + .description(databaseContext.getDescription()) + .context(databaseContext.getContext()) + .createdAt(databaseContext.getCreatedAt()) + .updatedAt(databaseContext.getUpdatedAt()) + .build()); + } + + @PostMapping("/context") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity createContext( + HttpServletRequest request, + HttpServletResponse response, + @RequestBody AgentContextRequestDTO dtoRequest){ + var databaseContext = agentContextService.create(dtoRequest); + return ResponseEntity.ok(AgentContextDTO.builder() + .id(databaseContext.getId()) + .name(databaseContext.getName()) + .description(databaseContext.getDescription()) + .context(databaseContext.getContext()) + .createdAt(databaseContext.getCreatedAt()) + .updatedAt(databaseContext.getUpdatedAt()) + .build()); + } + } diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java b/api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java index f8ecdc3f..98041cb6 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java @@ -40,6 +40,7 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestHeader; @@ -62,6 +63,7 @@ public class AgentBootstrapController extends BaseController { private final ZeroTrustClientService zeroTrustClientService; private final ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); final AppConfig appConfig; + private final AgentClientService agentClientService; public AgentBootstrapController( @@ -72,7 +74,8 @@ public AgentBootstrapController( CryptoService cryptoService, SessionTrackingService sessionTrackingService, KeycloakService keycloakService, ATPLPolicyService atplPolicyService, ZeroTrustAccessTokenService ztatService, ZeroTrustRequestService ztrService, AgentService agentService, - ZeroTrustClientService zeroTrustClientService, AppConfig appConfig + ZeroTrustClientService zeroTrustClientService, AppConfig appConfig, + AgentClientService agentClientService ) { super(userService, systemOptions, errorOutputService); this.auditService = auditService; @@ -85,6 +88,7 @@ public AgentBootstrapController( this.agentService = agentService; this.zeroTrustClientService = zeroTrustClientService; this.appConfig = appConfig; + this.agentClientService = agentClientService; } @@ -168,7 +172,26 @@ public ResponseEntity launchPod( @RequestBody AgentRegistrationDTO registrationDTO, HttpServletRequest request, HttpServletResponse response ) throws GeneralSecurityException, IOException, ZtatException { + try{ + log.info("Launching agent pod with ID: {}", registrationDTO.getAgentName()); + var status = getAgentStatus( registrationDTO.getAgentName(), request, response); + if ( status != null ) { + var body = status.getBody(); + if (body != null) { + + if (body.contains("Running") || body.contains("Pending")) { + log.info("Agent {} is already running or pending", registrationDTO.getAgentName()); + return ResponseEntity.ok("{\"status\": \"already exists\"}"); + } else { + log.warn("Agent {} is not running, attempting to launch again", registrationDTO.getAgentName()); + } + } + } + } catch (Exception e) { + log.error("Error getting agent status", e); + + } var operatingUser = getOperatingUser(request, response ); zeroTrustClientService.callAuthenticatedPostOnApi(appConfig.getSentriusLauncherService(), "agent/launcher/create", registrationDTO); @@ -192,6 +215,17 @@ public ResponseEntity deletePod( } + @GetMapping("/launcher/status") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public ResponseEntity getAgentStatus( + @RequestParam("agentId") String agentId, HttpServletRequest request, HttpServletResponse response + ) throws GeneralSecurityException, IOException, ZtatException { + + String podResponse = + agentClientService.getAgentPodStatus(appConfig.getSentriusLauncherService(), agentId); + // bootstrap with a default policy + return ResponseEntity.ok("{\"status\": \"" + podResponse + "\"}"); + } diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/HostApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/HostApiController.java index 8306c824..de7b96a5 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/HostApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/HostApiController.java @@ -196,7 +196,7 @@ public ResponseEntity connectSSHServer(HttpServletRequest request, H if (enclaveId == null || hostId == null) { return ResponseEntity.badRequest().build(); } - if (systemOptions.getSshEnabled() == false){ + if (systemOptions.getLockdownEnabled() == true){ node.put("sessionId",""); node.put("errorToUser","SSH is disabled"); return ResponseEntity.ok(node); diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java index ccedfb67..26721f4e 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java @@ -73,20 +73,20 @@ protected SystemApiController( this.configurationApplicationTask = configurationApplicationTask; } - @GetMapping("/settings/sshEnabled") - public ResponseEntity getSSHEnabled() { + @GetMapping("/settings/lockdownEnabled") + public ResponseEntity getLockdownEnabled() { ObjectNode node = JsonUtil.MAPPER.createObjectNode(); - node.put("sshEnabled", systemOptions.getSshEnabled()); + node.put("lockdownEnabled", systemOptions.getLockdownEnabled()); return ResponseEntity.ok(node); } - @PutMapping("/settings/ssh/toggle") + @PutMapping("/settings/lockdown/toggle") @LimitAccess(applicationAccess ={ ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) - public ResponseEntity toggleSSHEnabled() { - log.info("Toggling SSH enabled"); + public ResponseEntity toggleLockdown() { + log.info("Toggling Lockdown enabled"); ObjectNode node = JsonUtil.MAPPER.createObjectNode(); - systemOptions.setValue("sshEnabled", !systemOptions.getSshEnabled()); - node.put("sshEnabled", systemOptions.getSshEnabled()); + systemOptions.setValue("lockdownEnabled", !systemOptions.getLockdownEnabled()); + node.put("lockdownEnabled", systemOptions.getLockdownEnabled()); return ResponseEntity.ok(node); } diff --git a/api/src/main/java/io/sentrius/sso/controllers/view/ATPLConfigController.java b/api/src/main/java/io/sentrius/sso/controllers/view/ATPLConfigController.java index 58b0414e..618f5241 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/view/ATPLConfigController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/view/ATPLConfigController.java @@ -64,4 +64,10 @@ public String configurePage(Model model, @RequestParam(name= "id", required=fals } return "sso/atpl/configure"; } + + @GetMapping("/chat") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public String chatPage(Model model) { + return "sso/atpl/chat"; + } } \ No newline at end of file diff --git a/api/src/main/java/io/sentrius/sso/controllers/view/AgentController.java b/api/src/main/java/io/sentrius/sso/controllers/view/AgentController.java index 58bcf2b9..3670cdd1 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/view/AgentController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/view/AgentController.java @@ -48,6 +48,12 @@ public String listAgents(Model m) { return "sso/agents/list_agents"; } + @GetMapping("/design/chat") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) + public String designAgent(Model m) { + return "sso/agents/design_chat"; + } + @GetMapping("/connections") @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_MANAGE_APPLICATION}) public String listConnections(Model m, @RequestParam("agentId") String agentId) throws GeneralSecurityException { @@ -59,4 +65,5 @@ public String listConnections(Model m, @RequestParam("agentId") String agentId) return "sso/agents/agent_comms"; } + } diff --git a/api/src/main/java/io/sentrius/sso/services/ATPLChatService.java b/api/src/main/java/io/sentrius/sso/services/ATPLChatService.java new file mode 100644 index 00000000..851f2cdd --- /dev/null +++ b/api/src/main/java/io/sentrius/sso/services/ATPLChatService.java @@ -0,0 +1,262 @@ +package io.sentrius.sso.services; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.sentrius.sso.core.services.ATPLPolicyService; +import io.sentrius.sso.core.trust.ATPLPolicy; +import io.sentrius.sso.core.trust.CapabilitySet; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +@Slf4j +@Service +@RequiredArgsConstructor +public class ATPLChatService { + + private final ATPLPolicyService atplPolicyService; + private final ObjectMapper objectMapper; + + public String processATPLChatMessage(String userMessage, Map context) { + String lowerMessage = userMessage.toLowerCase(); + + // Simple rule-based responses + if (lowerMessage.contains("start") || lowerMessage.contains("begin") || lowerMessage.contains("new")) { + return "I'll help you create a new ATPL policy. Let's start by understanding what your agent needs to do. " + + "What are the main capabilities or functions your agent should have? For example:\n" + + "• File operations (read, write, delete)\n" + + "• System monitoring (CPU, memory, disk)\n" + + "• Network communications\n" + + "• Database access\n" + + "• External API calls"; + } + + if (lowerMessage.contains("endpoint")) { + List existingPolicies = atplPolicyService.getAllPolicies(); + StringBuilder response = new StringBuilder(); + response.append("For endpoints, consider what APIs or services your agent needs to access. "); + + if (!existingPolicies.isEmpty()) { + response.append("Here are some endpoints from existing policies:\n"); + existingPolicies.forEach(policy -> { + if (policy.getCapabilities() != null && policy.getCapabilities().getPrimitives() != null) { + policy.getCapabilities().getPrimitives().forEach(cap -> { + if (cap.getEndpoints() != null && !cap.getEndpoints().isEmpty()) { + response.append("• ").append(String.join(", ", cap.getEndpoints())).append("\n"); + } + }); + } + }); + } else { + response.append("Common examples include:\n"); + response.append("• /api/v1/data/read - for reading data\n"); + response.append("• /api/v1/system/status - for system status\n"); + response.append("• /api/v1/execute - for executing operations\n"); + } + + response.append("\nWhat specific endpoints does your agent need to access?"); + return response.toString(); + } + + if (lowerMessage.contains("command")) { + return "For commands, think about what system commands your agent might need to execute. " + + "Examples include:\n" + + "• File operations: ls, cat, grep, find, cp, mv\n" + + "• System monitoring: ps, top, df, netstat, ss\n" + + "• Process management: systemctl, service, kill\n" + + "• Network operations: ping, wget, curl\n" + + "• Data processing: awk, sed, sort, uniq\n\n" + + "What commands should your agent be able to run?"; + } + + if (lowerMessage.contains("activity") || lowerMessage.contains("activities")) { + return "Activities define what your agent can do at a high level. Examples include:\n" + + "• file_operations - reading, writing, managing files\n" + + "• system_monitoring - checking system health and metrics\n" + + "• data_processing - transforming, analyzing data\n" + + "• network_access - making network requests\n" + + "• user_management - managing user accounts\n" + + "• service_management - starting, stopping services\n\n" + + "What activities should your agent perform?"; + } + + if (lowerMessage.contains("existing") || lowerMessage.contains("policies")) { + List existingPolicies = atplPolicyService.getAllPolicies(); + if (existingPolicies.isEmpty()) { + return "No existing ATPL policies found. Would you like to create your first policy?"; + } + + StringBuilder response = new StringBuilder(); + response.append("Here are the existing ATPL policies:\n\n"); + + existingPolicies.forEach(policy -> { + response.append("**").append(policy.getPolicyId()).append("**\n"); + response.append("Version: ").append(policy.getVersion()).append("\n"); + if (policy.getDescription() != null) { + response.append("Description: ").append(policy.getDescription()).append("\n"); + } + + if (policy.getCapabilities() != null && policy.getCapabilities().getPrimitives() != null) { + response.append("Capabilities:\n"); + policy.getCapabilities().getPrimitives().forEach(cap -> { + response.append("• ").append(cap.getId()).append(": ").append(cap.getDescription()).append("\n"); + }); + } + response.append("\n"); + }); + + response.append("Would you like to create a new policy or modify an existing one?"); + return response.toString(); + } + + if (lowerMessage.contains("help") || lowerMessage.contains("what")) { + return "I'm here to help you create ATPL (Agent Trust Policy Language) policies. " + + "I can assist with:\n\n" + + "• **Policy Structure** - Understanding the basic components\n" + + "• **Endpoints** - Defining what APIs your agent can access\n" + + "• **Commands** - Specifying what system commands are allowed\n" + + "• **Activities** - Defining high-level capabilities\n" + + "• **Security** - Ensuring appropriate access controls\n\n" + + "To get started, tell me about your agent's purpose or ask about any specific aspect!"; + } + + if (lowerMessage.contains("security") || lowerMessage.contains("access")) { + return "Security is crucial for ATPL policies. Consider these aspects:\n\n" + + "• **Principle of Least Privilege** - Only grant necessary permissions\n" + + "• **Risk Assessment** - Tag high-risk capabilities appropriately\n" + + "• **Endpoint Validation** - Ensure endpoints are legitimate and necessary\n" + + "• **Command Restrictions** - Avoid dangerous commands like rm -rf, sudo\n" + + "• **Activity Boundaries** - Clearly define what activities are allowed\n\n" + + "What security considerations do you have for your agent?"; + } + + // Default response + return "I understand you want to configure an ATPL policy. I can help you define:\n" + + "• **Endpoints** - API endpoints your agent can access\n" + + "• **Commands** - System commands your agent can execute\n" + + "• **Activities** - High-level capabilities your agent needs\n\n" + + "What would you like to configure first? Or tell me more about what your agent needs to do."; + } + + public ObjectNode generateATPLPolicy(Map configuration) { + try { + ObjectNode policyNode = objectMapper.createObjectNode(); + + // Basic information + policyNode.put("version", "v0"); + policyNode.put("policy_id", (String) configuration.getOrDefault("policy_id", "generated_policy_" + System.currentTimeMillis())); + policyNode.put("description", (String) configuration.getOrDefault("description", "Generated ATPL policy from chat session")); + + // Capabilities + ObjectNode capabilitiesNode = objectMapper.createObjectNode(); + ArrayNode primitivesArray = objectMapper.createArrayNode(); + + // Extract capabilities from session context + @SuppressWarnings("unchecked") + List> capabilities = (List>) configuration.get("capabilities"); + + if (capabilities != null) { + for (Map cap : capabilities) { + ObjectNode capNode = objectMapper.createObjectNode(); + capNode.put("id", (String) cap.get("id")); + capNode.put("description", (String) cap.get("description")); + + // Add endpoints + if (cap.containsKey("endpoints")) { + ArrayNode endpointsArray = objectMapper.createArrayNode(); + @SuppressWarnings("unchecked") + List endpoints = (List) cap.get("endpoints"); + endpoints.forEach(endpointsArray::add); + capNode.set("endpoints", endpointsArray); + } + + // Add commands + if (cap.containsKey("commands")) { + ArrayNode commandsArray = objectMapper.createArrayNode(); + @SuppressWarnings("unchecked") + List commands = (List) cap.get("commands"); + commands.forEach(commandsArray::add); + capNode.set("commands", commandsArray); + } + + // Add activities + if (cap.containsKey("activities")) { + ArrayNode activitiesArray = objectMapper.createArrayNode(); + @SuppressWarnings("unchecked") + List activities = (List) cap.get("activities"); + activities.forEach(activitiesArray::add); + capNode.set("activities", activitiesArray); + } + + primitivesArray.add(capNode); + } + } else { + // Create default capability based on conversation + ObjectNode defaultCap = objectMapper.createObjectNode(); + defaultCap.put("id", "basic_access"); + defaultCap.put("description", "Basic agent access capability"); + + ArrayNode endpoints = objectMapper.createArrayNode(); + endpoints.add("/api/v1/status"); + defaultCap.set("endpoints", endpoints); + + ArrayNode commands = objectMapper.createArrayNode(); + commands.add("ls"); + commands.add("ps"); + defaultCap.set("commands", commands); + + ArrayNode activities = objectMapper.createArrayNode(); + activities.add("monitoring"); + defaultCap.set("activities", activities); + + primitivesArray.add(defaultCap); + } + + capabilitiesNode.set("primitives", primitivesArray); + policyNode.set("capabilities", capabilitiesNode); + + return policyNode; + + } catch (Exception e) { + log.error("Error generating ATPL policy: {}", e.getMessage(), e); + return objectMapper.createObjectNode(); + } + } + + public List suggestCapabilities(String userDescription) { + List suggestions = new ArrayList<>(); + + String description = userDescription.toLowerCase(); + + if (description.contains("read") || description.contains("view") || description.contains("get")) { + suggestions.add("read_access"); + } + if (description.contains("write") || description.contains("modify") || description.contains("update")) { + suggestions.add("write_access"); + } + if (description.contains("execute") || description.contains("run") || description.contains("command")) { + suggestions.add("execute_access"); + } + if (description.contains("admin") || description.contains("manage") || description.contains("control")) { + suggestions.add("admin_access"); + } + if (description.contains("network") || description.contains("connection") || description.contains("socket")) { + suggestions.add("network_access"); + } + if (description.contains("file") || description.contains("disk") || description.contains("storage")) { + suggestions.add("filesystem_access"); + } + if (description.contains("monitor") || description.contains("watch") || description.contains("observe")) { + suggestions.add("monitoring_access"); + } + if (description.contains("database") || description.contains("db") || description.contains("sql")) { + suggestions.add("database_access"); + } + + return suggestions; + } +} \ No newline at end of file diff --git a/api/src/main/resources/db/migration/V18__agent_launch.sql b/api/src/main/resources/db/migration/V18__agent_launch.sql new file mode 100644 index 00000000..63694ba6 --- /dev/null +++ b/api/src/main/resources/db/migration/V18__agent_launch.sql @@ -0,0 +1,17 @@ +CREATE TABLE agent_contexts ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + description TEXT, + context TEXT NOT NULL, -- YAML or any other string format + created_at TIMESTAMP WITH TIME ZONE DEFAULT now(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT now() +); + +CREATE TABLE agent_launches ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + agent_id TEXT NOT NULL, -- e.g., name or UUID of the launched agent + context_id UUID NOT NULL REFERENCES agent_contexts(id), + launched_by TEXT, -- who or what initiated it + launch_parameters TEXT, -- optional overrides or launch args + created_at TIMESTAMP WITH TIME ZONE DEFAULT now() +); \ No newline at end of file diff --git a/api/src/main/resources/static/js/add_system.js b/api/src/main/resources/static/js/add_system.js index 3752b900..0ae07c5e 100644 --- a/api/src/main/resources/static/js/add_system.js +++ b/api/src/main/resources/static/js/add_system.js @@ -9,14 +9,14 @@ document.addEventListener('DOMContentLoaded', function () { if (disableSSHButton) { - fetch(`/api/v1/system/settings/sshEnabled`) + fetch(`/api/v1/system/settings/lockdownEnabled`) .then(response => response.json()) .then(data => { - if (data.sshEnabled) { - disableSSHButton.innerText = 'Disable SSH'; + if (data.lockdownEnabled) { + disableSSHButton.innerText = 'LockDown Systems'; } else { - disableSSHButton.innerText = 'Enable SSH'; + disableSSHButton.innerText = 'Re-enable Systems'; } }) .catch(error => { @@ -25,7 +25,7 @@ document.addEventListener('DOMContentLoaded', function () { document.getElementById('disable-ssh-button').addEventListener('click', function(event) { event.preventDefault(); // Prevent the default anchor behavior const csrfToken = document.getElementById('csrf-token').value; // Get CSRF token value - fetch('/api/v1/system/settings/ssh/toggle', { + fetch('/api/v1/system/settings/lockdown/toggle', { method: 'PUT', // Specify PUT request headers: { 'Content-Type': 'application/json', // Optional, adjust based on your API @@ -34,10 +34,10 @@ document.addEventListener('DOMContentLoaded', function () { }).then(response => response.json()) .then(data => { if (data.sshEnabled) { - disableSSHButton.innerText = 'Disable SSH'; + disableSSHButton.innerText = 'LockDown Systems'; } else { - disableSSHButton.innerText = 'Enable SSH'; + disableSSHButton.innerText = 'Re-enable Systems'; } }) .catch(error => { diff --git a/api/src/main/resources/static/js/chat.js b/api/src/main/resources/static/js/chat.js index 64ca585c..aa47e530 100644 --- a/api/src/main/resources/static/js/chat.js +++ b/api/src/main/resources/static/js/chat.js @@ -3,18 +3,20 @@ // ========================= const chatSessions = new Map(); // key: agentId, value: ChatSession -window.addEventListener("beforeunload", persistChatSessions); +//window.addEventListener("beforeunload", persistChatSessions); // Restore on page load -(function restoreSessions() { +(function restoreSesions() { const saved = localStorage.getItem("openChats"); if (!saved) return; - const chatData = JSON.parse(saved); + /*const chatData = JSON.parse(saved); for (const [agentId, data] of Object.entries(chatData)) { const session = new ChatSession(data.agentName, data.agentId, data.sessionId, data.agentHost, data.messages); chatSessions.set(agentId, session); } + */ + })(); // ========================= @@ -281,11 +283,10 @@ export function switchToAgent(agentName,agentId, sessionId, agentHost) { } export function sendMessage(event) { + + if (event && event.key && event.key !== 'Enter') return; + console.log("Send message event:", event); - if (event.key !== "Enter"){ - console.log("Key pressed is not Enter, ignoring."); - return; - } const input = document.getElementById("chat-input"); const messageText = input.value.trim(); @@ -384,4 +385,7 @@ export async function fetchAvailableAgents() { } window.sendMessage = sendMessage; -window.toggleChat = toggleChat; \ No newline at end of file +window.toggleChat = toggleChat; +window.fetchAvailableAgents = fetchAvailableAgents; +window.chatSessions = chatSessions; +window.ChatSession = ChatSession; \ No newline at end of file diff --git a/api/src/main/resources/templates/sso/agents/design_chat.html b/api/src/main/resources/templates/sso/agents/design_chat.html new file mode 100644 index 00000000..af69acc8 --- /dev/null +++ b/api/src/main/resources/templates/sso/agents/design_chat.html @@ -0,0 +1,674 @@ + + + + + [[${systemOptions.systemLogoName}]] - Agent Designer + + + + + +
+
+ + +
+
+
+ +
+
+

Agent Designer

+

Let me help you create an agent

+
+ +
+
+ Welcome! I'm here to help you create an agent. + What kind of agent are you looking to configure? +
+
+ +
+
+ + +
+
+ +
+
Assistant is thinking... +
+
+ +
+ + +
+
+
+
+ + + + \ No newline at end of file diff --git a/api/src/main/resources/templates/sso/atpl/chat.html b/api/src/main/resources/templates/sso/atpl/chat.html new file mode 100644 index 00000000..20670b65 --- /dev/null +++ b/api/src/main/resources/templates/sso/atpl/chat.html @@ -0,0 +1,651 @@ + + + + + [[${systemOptions.systemLogoName}]] - ATPL Chat Configuration + + + + + +
+
+ + +
+
+
+ +
+
+

ATPL Configuration Assistant

+

Let me help you create a comprehensive Agent Trust Policy Language configuration

+
+ +
+
+ Welcome! I'm here to help you create an ATPL policy. + What kind of agent are you looking to configure? +
+
+ +
+
+ + + + +
+ +
+ + +
+
+ +
+
Assistant is thinking... +
+
+ +
+ + + +
+
+
+
+ + + + \ No newline at end of file diff --git a/api/src/main/resources/templates/sso/atpl/list.html b/api/src/main/resources/templates/sso/atpl/list.html index 00c1ebe8..3b7fee2c 100644 --- a/api/src/main/resources/templates/sso/atpl/list.html +++ b/api/src/main/resources/templates/sso/atpl/list.html @@ -72,7 +72,14 @@
diff --git a/api/src/main/resources/templates/sso/dashboard.html b/api/src/main/resources/templates/sso/dashboard.html index 60006ff0..adc32699 100755 --- a/api/src/main/resources/templates/sso/dashboard.html +++ b/api/src/main/resources/templates/sso/dashboard.html @@ -351,168 +351,168 @@ }); -