diff --git a/.local.env b/.local.env index e11552ac..b2826b1c 100644 --- a/.local.env +++ b/.local.env @@ -6,4 +6,4 @@ SENTRIUS_AI_AGENT_VERSION=1.1.264 LLMPROXY_VERSION=1.0.78 LAUNCHER_VERSION=1.0.82 AGENTPROXY_VERSION=1.0.85 -SSHPROXY_VERSION=1.0.87 \ No newline at end of file +SSHPROXY_VERSION=1.0.88 diff --git a/.local.env.bak b/.local.env.bak index e11552ac..ea42dddc 100644 --- a/.local.env.bak +++ b/.local.env.bak @@ -6,4 +6,4 @@ SENTRIUS_AI_AGENT_VERSION=1.1.264 LLMPROXY_VERSION=1.0.78 LAUNCHER_VERSION=1.0.82 AGENTPROXY_VERSION=1.0.85 -SSHPROXY_VERSION=1.0.87 \ No newline at end of file +SSHPROXY_VERSION=1.0.87 diff --git a/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java b/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java index 4146c842..0c2ce70d 100644 --- a/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java +++ b/agent-launcher/src/main/java/io/sentrius/agent/launcher/service/PodLauncherService.java @@ -178,6 +178,7 @@ public V1Pod launchAgentPod(AgentRegistrationDTO agent) throws Exception { List argList = new ArrayList<>(); argList.add("--spring.config.location=file:/config/agent.properties"); argList.add("--agent.namePrefix=" + agentId); + argList.add("--agent.type=" + agent.getAgentType()); argList.add("--agent.clientId=" + agent.getClientId()); argList.add("--agent.listen.websocket=true"); argList.add("--agent.callback.url=" + constructedCallbackUrl); diff --git a/agent-launcher/src/main/resources/application.properties b/agent-launcher/src/main/resources/application.properties index 6a83c490..6a493796 100644 --- a/agent-launcher/src/main/resources/application.properties +++ b/agent-launcher/src/main/resources/application.properties @@ -8,4 +8,4 @@ spring.main.web-application-type=servlet spring.thymeleaf.enabled=true spring.freemarker.enabled=false -sentrius.agent.registry=local \ No newline at end of file +sentrius.agent.registry=local diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/BaseEnterpriseAgent.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/BaseEnterpriseAgent.java new file mode 100644 index 00000000..980834fe --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/BaseEnterpriseAgent.java @@ -0,0 +1,70 @@ +package io.sentrius.agent.analysis.agents.agents; + +import java.util.concurrent.TimeUnit; +import com.fasterxml.jackson.databind.node.ArrayNode; +import io.sentrius.agent.analysis.agents.verbs.AgentVerbs; +import io.sentrius.sso.core.dto.agents.AgentExecution; +import io.sentrius.sso.core.dto.agents.AgentExecutionContextDTO; +import io.sentrius.sso.core.dto.ztat.ZtatRequestDTO; +import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.ZeroTrustClientService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.context.event.ApplicationReadyEvent; +import org.springframework.context.ApplicationListener; + +@Slf4j +public abstract class BaseEnterpriseAgent implements ApplicationListener { + + @Autowired + protected final AgentVerbs agentVerbs; + @Autowired + protected final ZeroTrustClientService zeroTrustClientService; + @Autowired + protected final AgentClientService agentClientService; + @Autowired + protected final VerbRegistry verbRegistry; + + protected BaseEnterpriseAgent( + AgentVerbs agentVerbs, ZeroTrustClientService zeroTrustClientService, AgentClientService agentClientService, + VerbRegistry verbRegistry + ) { + this.agentVerbs = agentVerbs; + this.zeroTrustClientService = zeroTrustClientService; + this.agentClientService = agentClientService; + this.verbRegistry = verbRegistry; + } + + + protected ArrayNode promptAgent(AgentExecution execution) throws ZtatException { + return promptAgent(execution); + } + + protected ArrayNode promptAgent(AgentExecution execution, AgentExecutionContextDTO contextDTO) throws ZtatException { + while(true){ + try { + log.info("Prompting agent..."); + return agentVerbs.promptAgent(execution,contextDTO); + } catch (ZtatException e) { + log.info("Mechanisms {}" , e.getMechanisms()); + var endpoint = zeroTrustClientService.createEndPointRequest("prompt_agent", e.getEndpoint()); + ZtatRequestDTO ztatRequestDTO = ZtatRequestDTO.builder() + .user(execution.getUser()) + .command(endpoint.toString()) + .justification("Registered Agent requires ability to prompt LLM endpoints to begin operations") + .summary("Registered Agent requires ability to prompt LLM endpoints to begin operations") + .build(); + var request = zeroTrustClientService.requestZtatToken(execution, execution.getUser(),ztatRequestDTO); + + var token = zeroTrustClientService.awaitZtatToken(execution, execution.getUser(), request, 60, + TimeUnit.MINUTES); + execution.setZtatToken(token); + } catch (Exception e) { + log.error(e.getMessage()); + throw new RuntimeException(e); + } + } + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java index 5a01b496..073f4dc3 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/ChatAgent.java @@ -1,77 +1,76 @@ package io.sentrius.agent.analysis.agents.agents; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.UUID; -import java.util.concurrent.TimeUnit; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.node.ArrayNode; import io.sentrius.agent.analysis.agents.verbs.AgentVerbs; +import io.sentrius.agent.analysis.agents.verbs.ChatVerbs; import io.sentrius.agent.analysis.api.AgentKeyService; import io.sentrius.agent.analysis.api.UserCommunicationService; +import io.sentrius.agent.analysis.model.LLMResponse; import io.sentrius.agent.config.AgentConfigOptions; import io.sentrius.sso.core.dto.UserDTO; import io.sentrius.sso.core.dto.agents.AgentExecution; -import io.sentrius.sso.core.dto.ztat.ZtatRequestDTO; +import io.sentrius.sso.core.dto.agents.AgentExecutionContextDTO; import io.sentrius.sso.core.exceptions.ZtatException; import io.sentrius.sso.core.model.security.Ztat; +import io.sentrius.sso.core.model.verbs.VerbResponse; import io.sentrius.sso.core.services.agents.AgentClientService; import io.sentrius.sso.core.services.agents.AgentExecutionService; import io.sentrius.sso.core.services.agents.ZeroTrustClientService; import io.sentrius.sso.core.services.security.KeycloakService; import io.sentrius.sso.core.utils.JsonUtil; +import io.sentrius.sso.genai.Message; import jakarta.annotation.PreDestroy; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.event.ApplicationReadyEvent; -import org.springframework.context.ApplicationListener; import org.springframework.stereotype.Component; @Slf4j @Component -@RequiredArgsConstructor @ConditionalOnProperty(name = "agents.ai.chat.agent.enabled", havingValue = "true", matchIfMissing = false) -public class ChatAgent implements ApplicationListener { +public class ChatAgent extends BaseEnterpriseAgent { final ZeroTrustClientService zeroTrustClientService; final AgentClientService agentClientService; final VerbRegistry verbRegistry; - final AgentVerbs agentVerbs; final AgentExecutionService agentExecutionService; final UserCommunicationService userCommunicationService; final AgentConfigOptions agentConfigOptions; final AgentKeyService agentKeyService; private final KeycloakService keycloakService; + final ChatVerbs chatVerbs; private volatile boolean running = true; private Thread workerThread; private AgentExecution agentExecution; - public ArrayNode promptAgent(AgentExecution execution) throws ZtatException { - while(true){ - try { - log.info("Prompting agent..."); - return agentVerbs.promptAgent(execution,null); - } catch (ZtatException e) { - log.info("Mechanisms {}" , e.getMechanisms()); - var endpoint = zeroTrustClientService.createEndPointRequest("prompt_agent", e.getEndpoint()); - ZtatRequestDTO ztatRequestDTO = ZtatRequestDTO.builder() - .user(execution.getUser()) - .command(endpoint.toString()) - .justification("Registered Agent requires ability to prompt LLM endpoints to begin operations") - .summary("Registered Agent requires ability to prompt LLM endpoints to begin operations") - .build(); - var request = zeroTrustClientService.requestZtatToken(execution, execution.getUser(),ztatRequestDTO); - - var token = zeroTrustClientService.awaitZtatToken(execution, execution.getUser(), request, 60, - TimeUnit.MINUTES); - execution.setZtatToken(token); - } catch (Exception e) { - log.error(e.getMessage()); - throw new RuntimeException(e); - } - } + + @Autowired + public ChatAgent( + AgentVerbs agentVerbs, ZeroTrustClientService zeroTrustClientService, AgentClientService agentClientService, + VerbRegistry verbRegistry, AgentExecutionService agentExecutionService, UserCommunicationService userCommunicationService, + AgentConfigOptions agentConfigOptions, AgentKeyService agentKeyService, KeycloakService keycloakService, + ChatVerbs chatVerbs + ) { + super(agentVerbs, zeroTrustClientService, agentClientService, verbRegistry); + this.zeroTrustClientService = zeroTrustClientService; + this.agentClientService = agentClientService; + this.verbRegistry = verbRegistry; + this.agentExecutionService = agentExecutionService; + this.userCommunicationService = userCommunicationService; + this.agentConfigOptions = agentConfigOptions; + this.agentKeyService = agentKeyService; + this.keycloakService = keycloakService; + this.chatVerbs = chatVerbs; } @Override @@ -146,6 +145,38 @@ public void onApplicationEvent(final ApplicationReadyEvent event) { int allowedFailures = 20; log.info("Agent Registered..."); + AgentExecutionContextDTO agentExecutionContext = AgentExecutionContextDTO.builder().build(); + agentExecutionService.setExecutionContextDTO(agentExecution, agentExecutionContext); + LLMResponse response = null; + AgentConfig config = null; + try { + config = chatVerbs.getAgentConfig(agentExecution); + } catch (IOException e) { + throw new RuntimeException(e); + } catch (ZtatException e) { + throw new RuntimeException(e); + } + PromptBuilder promptBuilder = new PromptBuilder(verbRegistry, config); + var prompt = promptBuilder.buildPrompt(false); + try { + if (agentConfigOptions.getType().equalsIgnoreCase("chat-autonomous")) { + + response = chatVerbs.promptAgent(agentExecution, agentExecutionContext, prompt); + } + } catch (ZtatException e) { + throw new RuntimeException(e); + } catch (IOException e) { + throw new RuntimeException(e); + } + + + if (agentConfigOptions.getType().equalsIgnoreCase("chat-autonomous") && response == null) { + log.error("Chat autonomous agent mode enabled but no response received from promptAgent, shutting down..."); + throw new RuntimeException("Chat autonomous agent mode enabled but no response received from promptAgent"); + } + VerbResponse lastVerbResponse = null; + LLMResponse nextResponse = null; + List verbResponses = new ArrayList<>(); while(running) { @@ -153,8 +184,70 @@ public void onApplicationEvent(final ApplicationReadyEvent event) { Thread.sleep(5_000); agentClientService.heartbeat(agentExecution, agentExecution.getUser().getUsername()); + if (agentConfigOptions.getType().equalsIgnoreCase("chat-autonomous")) { + log.info("Chat autonomous agent mode enabled, executing workload..."); + VerbResponse priorResponse = null; + Map args = new HashMap<>(); + + var arguments = response.getArguments(); + if (null != response) { + if (response.getNextOperation() != null && !response.getNextOperation().isEmpty()) { + var executionResponse = verbRegistry.execute( + agentExecution, + agentExecutionContext, + lastVerbResponse, + response.getNextOperation(), arguments + ); + verbResponses.add(executionResponse); + lastVerbResponse = executionResponse; + + +// chatAgent.getAgentExecution().addMessages(Message.builder().role("System") +// .content("System executed operation: " + response.getNextOperation()).build()); + var responses = agentExecutionContext.getAgentDataList(); + var planResponse = + responses.isEmpty() ? "" : + responses.get(responses.size() - 1).asText(); + nextResponse = chatVerbs.interpret_plan_response( + agentExecution, + agentExecutionContext, + verbRegistry.getVerbs().get(response.getNextOperation()), + planResponse + ); + + var memory = agentExecutionContext.flushPersistentMemory(); + if (memory != null) { + for(var memoryEntry : memory.entrySet()){ + agentClientService.storeMemory(agentExecution, + agentExecutionContext.getAgentContext().getName(), + io.sentrius.sso.core.dto.agents.AgentMemoryDTO.builder() + .agentName(agentExecutionContext.getAgentContext().getName()) + .memoryKey(memoryEntry.getKey()) + .memoryValue(memoryEntry.getValue().toString()) + .build()); + } + } + + + response = nextResponse; + } + + }else { + response = chatVerbs.promptAgent(agentExecution, agentExecutionContext, prompt); + + } + + continue; + } allowedFailures = 20; // Reset allowed failures on successful heartbeat } catch (ZtatException | Exception ex) { + agentExecutionContext.addMessages(Message.builder().role("system").content( + "You caused the following error. Please re-validate you chose the right operations or " + + "endpoints for the context" + + ex.getMessage()).build()); + + + ex.printStackTrace(); if (allowedFailures-- <= 0) { log.error("Failed to heartbeat agent after multiple attempts, shutting down..."); throw new RuntimeException(ex); diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java index a9fbb122..2db49cee 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/agents/RegisteredAgent.java @@ -9,6 +9,7 @@ import io.sentrius.agent.config.AgentConfigOptions; import io.sentrius.sso.core.dto.agents.AgentExecution; import io.sentrius.sso.core.dto.agents.AgentExecutionContextDTO; +import io.sentrius.sso.core.dto.agents.AgentMemoryDTO; import io.sentrius.sso.core.dto.ztat.ZtatRequestDTO; import io.sentrius.sso.core.exceptions.ZtatException; import io.sentrius.sso.core.model.security.Ztat; @@ -22,6 +23,7 @@ import jakarta.annotation.PreDestroy; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.ApplicationListener; @@ -29,15 +31,10 @@ @Slf4j @Component -@RequiredArgsConstructor @ConditionalOnProperty(name = "agents.ai.registered.agent.enabled", havingValue = "true", matchIfMissing = false) -public class RegisteredAgent implements ApplicationListener { +public class RegisteredAgent extends BaseEnterpriseAgent { - final ZeroTrustClientService zeroTrustClientService; - final AgentClientService agentClientService; - final VerbRegistry verbRegistry; - final AgentVerbs agentVerbs; final AgentExecutionService agentExecutionService; final AgentConfigOptions agentConfigOptions; final AgentKeyService agentKeyService; @@ -46,6 +43,19 @@ public class RegisteredAgent implements ApplicationListener verbs = new HashMap<>(); private final Map instances = new HashMap<>(); + private static final String [] AGENT_MARKINGS = new String[] {"SENTRIUS_INTERNAL"}; + private final AgentEndpointDiscoveryService agentEndpointDiscoveryService; private List endpoints = new ArrayList<>(); public void scanEndpoints(AgentExecution execution) throws ZtatException, JsonProcessingException { synchronized (this) { - var endpoints = agentClientService.getAvailableEndpoints(execution); + endpointRegistry.loadEndpoints(execution); + var endpoints = endpointRegistry.getAll(); log.info("Scanning endpoints for verbs..."); var verbs = agentClientService.getAvailableVerbs(execution); @@ -183,7 +190,9 @@ public VerbResponse execute(AgentExecution agentExecution, // add the output if (null != thisVerb.getReturnName() && !thisVerb.getReturnName().isEmpty()) { contextDTO.addToMemory(thisVerb.getReturnName(), execNode); + contextDTO.addToPersistentMemory(thisVerb.getReturnName(), execNode, "VERB", AGENT_MARKINGS); } else { + contextDTO.addToPersistentMemory(verb, execNode, "VERB", AGENT_MARKINGS); contextDTO.addToMemory(verb, execNode); } @@ -219,8 +228,10 @@ public VerbResponse execute(AgentExecution agentExecution, JsonNode execNode = JsonUtil.MAPPER.valueToTree(exec); // add the output if (null != thisVerb.getReturnName() && !thisVerb.getReturnName().isEmpty()) { + contextDTO.addToPersistentMemory(thisVerb.getReturnName(), execNode, "VERB", AGENT_MARKINGS); contextDTO.addToMemory(thisVerb.getReturnName(), execNode); } else { + contextDTO.addToPersistentMemory(verb, execNode, "VERB", AGENT_MARKINGS); contextDTO.addToMemory(verb, execNode); } diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java index e853b5c5..796e6bee 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/AgentVerbs.java @@ -22,7 +22,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; -import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -31,8 +30,11 @@ import io.sentrius.agent.analysis.agents.agents.VerbRegistry; import io.sentrius.agent.analysis.model.AssessedTerminal; import io.sentrius.agent.analysis.model.Assessment; +import io.sentrius.agent.analysis.model.LLMResponse; import io.sentrius.agent.analysis.model.ZtatAsessment; import io.sentrius.agent.analysis.model.ZtatResponse; +import io.sentrius.agent.services.EndpointRegistry; +import io.sentrius.agent.services.EndpointSearcher; import io.sentrius.sso.core.dto.AgentCommunicationDTO; import io.sentrius.sso.core.dto.AgentRegistrationDTO; import io.sentrius.sso.core.dto.ZtatDTO; @@ -53,6 +55,7 @@ import io.sentrius.sso.core.trust.Capability; import io.sentrius.sso.core.trust.CapabilitySet; import io.sentrius.sso.core.utils.JsonUtil; +import io.sentrius.sso.core.utils.ListUtils; import io.sentrius.sso.genai.Message; import io.sentrius.sso.genai.Response; import io.sentrius.sso.genai.model.LLMRequest; @@ -72,9 +75,8 @@ public class AgentVerbs extends VerbBase { final ZeroTrustClientService zeroTrustClientService; final LLMService llmService; final VerbRegistry verbRegistry; - - - + final EndpointRegistry endpointRegistry; + final EndpointSearcher endpointSearcher; final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); // Jackson ObjectMapper for YAML parsing private final AgentExecutionService agentExecutionService; @@ -83,20 +85,24 @@ public class AgentVerbs extends VerbBase { * Constructs an `AgentVerbs` instance with the required services and registry. * * @param zeroTrustClientService The service for Zero Trust client interactions. - * @param llmService The service for interacting with the LLM (Large Language Model). - * @param verbRegistry The registry containing available verbs and their metadata. - * @throws com.fasterxml.jackson.core.JsonProcessingException If there is an error processing JSON during initialization. + * @param llmService The service for interacting with the LLM (Large Language Model). + * @param verbRegistry The registry containing available verbs and their metadata. + * @throws com.fasterxml.jackson.core.JsonProcessingException If there is an error processing JSON during + * initialization. */ - public AgentVerbs( @Value("${agent.ai.config}") String agentConfigFile, - @Value("${agent.ai.context.db.id:none}") String agentDatabaseContext, - ZeroTrustClientService zeroTrustClientService, LLMService llmService, VerbRegistry verbRegistry, - AgentClientService agentService, - AgentExecutionService agentExecutionService + public AgentVerbs( + @Value("${agent.ai.config}") String agentConfigFile, + @Value("${agent.ai.context.db.id:none}") String agentDatabaseContext, + ZeroTrustClientService zeroTrustClientService, LLMService llmService, VerbRegistry verbRegistry, + AgentClientService agentService, EndpointRegistry endpointRegistry, EndpointSearcher endpointSearcher, + AgentExecutionService agentExecutionService ) throws JsonProcessingException { super(agentConfigFile, agentDatabaseContext, agentService); this.zeroTrustClientService = zeroTrustClientService; this.llmService = llmService; this.verbRegistry = verbRegistry; + this.endpointRegistry = endpointRegistry; + this.endpointSearcher = endpointSearcher; log.info("Loading agent config from {}", agentConfigFile); this.agentExecutionService = agentExecutionService; @@ -107,11 +113,13 @@ public AgentVerbs( @Value("${agent.ai.config}") String agentConfigFile, * * @return An `ArrayNode` containing the plan generated by the agent. * @throws io.sentrius.sso.core.exceptions.ZtatException If there is an error during the operation. - * @throws java.io.IOException If there is an error reading the configuration file. + * @throws java.io.IOException If there is an error reading the configuration file. */ - @Verb(name = "prompt_agent", returnType = ArrayNode.class, description = "Prompts for agent workload.", - isAiCallable = false, requiresTokenManagement = true) - public ArrayNode promptAgent(AgentExecution execution,AgentExecutionContextDTO context) throws ZtatException, + @Verb( + name = "prompt_agent", returnType = ArrayNode.class, description = "Prompts for agent workload.", + isAiCallable = false, requiresTokenManagement = true + ) + public ArrayNode promptAgent(AgentExecution execution, AgentExecutionContextDTO context) throws ZtatException, IOException { AgentConfig config = getAgentConfig(execution); @@ -125,7 +133,9 @@ public ArrayNode promptAgent(AgentExecution execution,AgentExecutionContextDTO c LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); - context.addMessages( messages ); + if (null != context ) { + context.addMessages(messages); + } Response response = JsonUtil.MAPPER.readValue(resp, Response.class); //log.info("Response is {}", resp); for (Response.Choice choice : response.getChoices()) { @@ -151,27 +161,29 @@ public ArrayNode promptAgent(AgentExecution execution,AgentExecutionContextDTO c } - /** * Chats with an agent to justify operations based on the provided arguments. * * @return A string response from the agent. * @throws io.sentrius.sso.core.exceptions.ZtatException If there is an error during the operation. - * @throws java.io.IOException If there is an error reading the configuration file. + * @throws java.io.IOException If there is an error reading the configuration file. */ - @Verb(name = "justify_operations", description = "Chats with an agent to justify operations.", isAiCallable = - false, requiresTokenManagement = true) - public String justifyAgent(AgentExecution execution, AgentExecutionContextDTO context, ZtatRequestDTO ztatRequest, - AssessedTerminal reason) throws ZtatException, + @Verb( + name = "justify_operations", description = "Chats with an agent to justify operations.", isAiCallable = + false, requiresTokenManagement = true + ) + public String justifyAgent( + AgentExecution execution, AgentExecutionContextDTO context, ZtatRequestDTO ztatRequest, + AssessedTerminal reason + ) throws ZtatException, IOException, InterruptedException, TimeoutException { - - var status = zeroTrustClientService.getTokenStatus(execution, execution.getUser(), ztatRequest.getRequestId()); - log.info("Status: {} for {} ", status, ztatRequest); - if ("approved".equals(status.get("status").asText())) { - return status.get("ztat_token").asText(); - } + var status = zeroTrustClientService.getTokenStatus(execution, execution.getUser(), ztatRequest.getRequestId()); + log.info("Status: {} for {} ", status, ztatRequest); + if ("approved".equals(status.get("status").asText())) { + return status.get("ztat_token").asText(); + } InputStream assessZtatStream = getClass().getClassLoader().getResourceAsStream("respond-ztat.json"); if (assessZtatStream == null) { @@ -179,93 +191,97 @@ public String justifyAgent(AgentExecution execution, AgentExecutionContextDTO co } AtatRequest atat = - AtatRequest.builder().requestId(ztatRequest.getRequestId()).requestedAction(ztatRequest.getCommand()).build(); + AtatRequest.builder().requestId(ztatRequest.getRequestId()).requestedAction(ztatRequest.getCommand()) + .build(); String respondZtat = new String(assessZtatStream.readAllBytes()); - while(!status.equals("approved")) { + while (!status.equals("approved")) { - Thread.sleep(5_000); + Thread.sleep(5_000); - status = zeroTrustClientService.getTokenStatus(execution, execution.getUser(), ztatRequest.getRequestId()); - log.info("Status: {} for {} ", status, ztatRequest); - if ("approved".equals(status.get("status").asText())) { - return status.get("ztat_token").asText(); - } + status = zeroTrustClientService.getTokenStatus(execution, execution.getUser(), ztatRequest.getRequestId()); + log.info("Status: {} for {} ", status, ztatRequest); + if ("approved".equals(status.get("status").asText())) { + return status.get("ztat_token").asText(); + } - Set commsIds = agentClientService.getCommunicationIds(execution, ztatRequest); + Set commsIds = agentClientService.getCommunicationIds(execution, ztatRequest); - commsIds.remove(execution.getCommunicationId()); + commsIds.remove(execution.getCommunicationId()); - if (commsIds.isEmpty()) { - continue; - } + if (commsIds.isEmpty()) { + continue; + } - if (commsIds.size() > 1) { - // get the first one - log.info("CommsIds is {}", commsIds); - } + if (commsIds.size() > 1) { + // get the first one + log.info("CommsIds is {}", commsIds); + } - var commsId = commsIds.iterator().next(); - - AgentExecution newExecution = - AgentExecution.builder().executionId(execution.getExecutionId()).ztatToken(execution.getZtatToken()).communicationId(commsId).build(); - - var nextMessaged = agentClientService.getResponse(newExecution, ztatRequest, 1, TimeUnit.MINUTES); - Set otherAgents = Sets.newHashSet(); - Set communicationIds = new HashSet<>(); - if (!nextMessaged.isEmpty()) { - List messages = new ArrayList<>(); - messages.add(Message.builder().role("system").content("The following messages are " + - "communications between two agents. One agent is interpreting data from another and may " + - "ask questions. Please respond to the questions using the initial guidance layed out in " + - "the next messages").build()); - messages.addAll( reason.getMessages() ); - for (AgentCommunicationDTO agentCommunicationDTO : nextMessaged) { - if (agentCommunicationDTO.getTargetAgent().equals(execution.getUser().getUsername())) { - otherAgents.add( agentCommunicationDTO.getSourceAgent()); - } - communicationIds.add(agentCommunicationDTO.getCommunicationId()); - } - messages.add(Message.builder().role("system").content("please respond in the following json " + - "format: " + respondZtat).build()); - - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); - context.addMessages( messages ); - var resp = llmService.askQuestion(execution, chatRequest); - Response response = JsonUtil.MAPPER.readValue(resp, Response.class); - //log.info("Response is {}", resp); - for (Response.Choice choice : response.getChoices()) { - var content = choice.getMessage().getContent(); - if (content.startsWith("```json")) { - content = content.substring(7, content.length() - 3); - } + var commsId = commsIds.iterator().next(); + + AgentExecution newExecution = + AgentExecution.builder().executionId(execution.getExecutionId()).ztatToken(execution.getZtatToken()) + .communicationId(commsId).build(); + var nextMessaged = agentClientService.getResponse(newExecution, ztatRequest, 1, TimeUnit.MINUTES); + Set otherAgents = Sets.newHashSet(); + Set communicationIds = new HashSet<>(); + if (!nextMessaged.isEmpty()) { + List messages = new ArrayList<>(); + messages.add(Message.builder().role("system").content("The following messages are " + + "communications between two agents. One agent is interpreting data from another and may " + + "ask questions. Please respond to the questions using the initial guidance layed out in " + + "the next messages").build()); + messages.addAll(reason.getMessages()); + for (AgentCommunicationDTO agentCommunicationDTO : nextMessaged) { + if (agentCommunicationDTO.getTargetAgent().equals(execution.getUser().getUsername())) { + otherAgents.add(agentCommunicationDTO.getSourceAgent()); + } + communicationIds.add(agentCommunicationDTO.getCommunicationId()); + } + messages.add(Message.builder().role("system").content("please respond in the following json " + + "format: " + respondZtat).build()); - var ztatResponse = JsonUtil.MAPPER.readValue(content, - ZtatResponse.class); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + context.addMessages(messages); + var resp = llmService.askQuestion(execution, chatRequest); + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); + //log.info("Response is {}", resp); + for (Response.Choice choice : response.getChoices()) { + var content = choice.getMessage().getContent(); + if (content.startsWith("```json")) { + content = content.substring(7, content.length() - 3); + } - for(var agent : otherAgents ) { - for (var commId : communicationIds) { - AgentCommunicationDTO myResponse = AgentCommunicationDTO.builder() - .communicationId(commId) - .payload(JsonUtil.MAPPER.writeValueAsString(ztatResponse)) - .messageType("atat_chat_respond") - .sourceAgent(execution.getUser().getUsername()) - .targetAgent(agent) - .build(); - agentClientService.sendResponse(execution, myResponse, ztatRequest); - } + var ztatResponse = JsonUtil.MAPPER.readValue( + content, + ZtatResponse.class + ); + + for (var agent : otherAgents) { + for (var commId : communicationIds) { + AgentCommunicationDTO myResponse = AgentCommunicationDTO.builder() + .communicationId(commId) + .payload(JsonUtil.MAPPER.writeValueAsString(ztatResponse)) + .messageType("atat_chat_respond") + .sourceAgent(execution.getUser().getUsername()) + .targetAgent(agent) + .build(); + + agentClientService.sendResponse(execution, myResponse, ztatRequest); } - log.info("content is {}", content); } + log.info("content is {}", content); } + } - // check for messages + // check for messages - } + } return null; // return llmService.askQuestion(chatRequest); @@ -276,15 +292,18 @@ public String justifyAgent(AgentExecution execution, AgentExecutionContextDTO co * * @return An `ArrayNode` containing the assessment results. * @throws io.sentrius.sso.core.exceptions.ZtatException If there is an error during the operation. - * @throws java.io.IOException If there is an error reading the configuration file. + * @throws java.io.IOException If there is an error reading the configuration file. */ - @Verb(name = "assess_api_data", returnType = ArrayNode.class, description = "Accepts api server data based on the" + + @Verb( + name = "assess_api_data", returnType = ArrayNode.class, description = "Accepts api server data based on the" + " " + "context and seeks" + " to perform the assessment of risk by prompting the LLM. Can be used to assess data or request information " + "from " + - "users and/or agents, but not for assessing ztat requests.", requiresTokenManagement = true) - public List assessData(AgentExecution execution, AgentExecutionContextDTO agentContext) throws ZtatException, IOException { + "users and/or agents, but not for assessing ztat requests.", requiresTokenManagement = true + ) + public List assessData(AgentExecution execution, AgentExecutionContextDTO agentContext) + throws ZtatException, IOException { AgentConfig config = getAgentConfig(execution); log.info("Agent config loaded: {}", config); @@ -298,12 +317,11 @@ public List assessData(AgentExecution execution, AgentExecutio List messages = new ArrayList<>(); var context = config.getContext(); - var userMessage =Message.builder().role("user").content(obj.toString()).build(); + var userMessage = Message.builder().role("user").content(obj.toString()).build(); agentContext.addMessages(userMessage); messages.add(userMessage); - LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); @@ -324,18 +342,18 @@ public List assessData(AgentExecution execution, AgentExecutio } log.info("Object is {}", obj); } - }else { + } else { List messages = new ArrayList<>(); var context = config.getContext(); - messages.addAll( agentContext.getMessages()); + messages.addAll(agentContext.getMessages()); var assistantMessage = Message.builder().role("assistant").content("Assess the previous data, but respond with the " + "following format { \"assessment\"{ sessionId, risk, description} } where description is your " + "assessment, sessionId is a random UUID or a previously found sessionId, and risk is a measure of" + - " low, medium, and high of the previous data" ).build(); + " low, medium, and high of the previous data").build(); agentContext.addMessages(assistantMessage); @@ -361,19 +379,23 @@ public List assessData(AgentExecution execution, AgentExecutio return responses; } - @Verb(name = "list_ztat_requests", returnType = ArrayNode.class, description = "Lists zero trust access token " + + @Verb( + name = "list_ztat_requests", returnType = ArrayNode.class, description = "Lists zero trust access token " + "requests (ztats)" + " " + "to" + " review from API. Does not review ztats.", - requiresTokenManagement = true ) - public List getWork(AgentExecution token, Map args) throws ZtatException, IOException { + requiresTokenManagement = true + ) + public List getWork(AgentExecution token, Map args) throws ZtatException, IOException { List requests = new ArrayList<>(); var atatRequests = agentClientService.getAtatRequests(token); log.info("Atat requests: {}", atatRequests); - List dtos = JsonUtil.MAPPER.readValue(atatRequests, new TypeReference<>() { - }); + List dtos = JsonUtil.MAPPER.readValue( + atatRequests, new TypeReference<>() { + } + ); for (var dto : dtos) { Set communicationIds = Sets.newHashSet(dto.getCommunicationIds()); @@ -382,24 +404,27 @@ public List getWork(AgentExecution token, Map args) request.setUserName(dto.getUserName()); request.setRequestId(dto.getId().toString()); // get messages - request.setRequestedAction( dto.getSummary()); + request.setRequestedAction(dto.getSummary()); log.info("Request is {}", dto); List communicationMessages = new ArrayList<>(); - for(String commsId : dto.getCommunicationIds()){ - var communications = zeroTrustClientService.callGetOnApi(token,"/agent/communications/id", - Maps.immutableEntry("communicationId", List.of(commsId))); + for (String commsId : dto.getCommunicationIds()) { + var communications = zeroTrustClientService.callGetOnApi( + token, "/agent/communications/id", + Maps.immutableEntry("communicationId", List.of(commsId)) + ); var messages = JsonUtil.MAPPER.readTree(communications); - for(JsonNode message : messages) { + for (JsonNode message : messages) { if (message.has("payload") && message.has("messageType")) { var type = message.get("messageType").asText(); if (type.equalsIgnoreCase("chat_request")) { try { - LLMRequest msg = JsonUtil.MAPPER.readValue(message.get("payload").asText(), LLMRequest.class); + LLMRequest msg = + JsonUtil.MAPPER.readValue(message.get("payload").asText(), LLMRequest.class); log.info("Message is {} from {}", msg, message.get("payload").asText()); - communicationMessages.addAll(msg.getMessages()); + communicationMessages.addAll(msg.getMessages()); } catch (JsonProcessingException e) { log.error(e.getMessage()); @@ -417,9 +442,11 @@ public List getWork(AgentExecution token, Map args) return requests; } - @Verb(name = "assess_ztat_requests", returnType = ArrayNode.class, description = "Analyzes ztats " + + @Verb( + name = "assess_ztat_requests", returnType = ArrayNode.class, description = "Analyzes ztats " + "requests according to the by prompting the LLM. ", - requiresTokenManagement = true ) + requiresTokenManagement = true + ) public List analyzeAtatRequests(AgentExecution execution, List requests) throws ZtatException, IOException, TimeoutException { @@ -437,7 +464,7 @@ public List analyzeAtatRequests(AgentExecution execution, List responses = new ArrayList<>(); log.info("Size of requests {}", requests.size()); for (var request : requests) { - var originalMessages = request.getMessages().stream().map(message ->{ + var originalMessages = request.getMessages().stream().map(message -> { message.setRole("user"); return message; }).toList(); @@ -448,7 +475,9 @@ public List analyzeAtatRequests(AgentExecution execution, List analyzeAtatRequests(AgentExecution execution, List(originalMessages); - for(var newComm : newComms) { + for (var newComm : newComms) { if (newComm.getMessageType().equalsIgnoreCase("atat_chat_ask")) { var msg = JsonUtil.MAPPER.readValue(newComm.getPayload(), ZtatAsessment.class); var newMessage = @@ -508,11 +538,13 @@ public List analyzeAtatRequests(AgentExecution execution, List analyzeAtatRequests(AgentExecution execution, List 0); + } while (--max > 0); } - } return responses; } - @Verb(name = "create_agent_context", returnType = AgentContextDTO.class, description = "Creates an agent Context." + + @Verb( + name = "create_agent_context", returnType = AgentContextDTO.class, description = "Creates an agent Context." + " must be done before creating an agent.", - requiresTokenManagement = true, + requiresTokenManagement = true, returnName = "created_context", - exampleJson = "{ \"context\": \"Notify when a new user is added\" }") + exampleJson = "{ \"context\": \"Notify when a new user is added\" }" + ) public AgentContextDTO createAgentContext(AgentExecution execution, AgentExecutionContextDTO context) throws ZtatException, JsonProcessingException { log.info("Creating agent context"); @@ -567,18 +600,19 @@ public AgentContextDTO createAgentContext(AgentExecution execution, AgentExecuti String agentName = name.isPresent() ? name.get().toString() : "name"; if (!agentName.isEmpty()) { - agentName =agentName.replaceAll("_","-"); + agentName = agentName.replaceAll("_", "-"); } + var originalContext = context.getExecutionArgument("context"); - var requestDtoContext = context.getExecutionArgument("context").orElseThrow().toString(); - requestDtoContext += ". Please request endpoints to perform your work."; + var requestDtoContext = originalContext.orElseThrow().toString(); + requestDtoContext += ". Please request endpoints to perform your work."; AgentContextRequestDTO dto = AgentContextRequestDTO.builder().context(requestDtoContext). description(requestDtoContext).name(agentName).build(); var createdContext = agentClientService.createAgentContext(execution, dto); // Here you would typically create a context in your system, e.g., store it in a database or cache. - context.setAgentContext( AgentContextDTO.builder() + context.setAgentContext(AgentContextDTO.builder() .contextId(createdContext.getContextId()) .name(createdContext.getName()) .context(createdContext.getContext()) @@ -586,22 +620,113 @@ public AgentContextDTO createAgentContext(AgentExecution execution, AgentExecuti .build()); // load the endpoints + var messages = new ArrayList(); + + messages.add(Message.builder().role("system").content("The user will provide the context of what an agent to " + + "be created will do. Respond with a json response { \"endpoints_like\" : [ array ] } where array is the " + + "features " + + "or tools to be called. Do not put endpoints in there, just text and explanation of the endpoint. " + + "We'll perform a text " + + "search to find" + + " endpoints").build()); + messages.add(Message.builder().role("user").content(originalContext.get().asText()).build()); - ObjectNode endpointsLike = JsonUtil.MAPPER.createObjectNode(); - endpointsLike.put("context", requestDtoContext); - context.setExecutionArgs(endpointsLike); - var endpoints = getEndpointsLike(execution, context); - log.info("Endpoints like {}", endpoints); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + var resp = llmService.askQuestion(execution, chatRequest); - context.addToMemory("endpoints",endpoints); + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); +// log.info("Response is {}", resp); + ArrayNode endpointsLikeList = JsonUtil.MAPPER.createArrayNode(); + for (Response.Choice choice : response.getChoices()) { + var content = choice.getMessage().getContent(); + if (content.startsWith("```json")) { + content = content.substring(7, content.length() - 3); + } else if (content.startsWith("```")) { + content = content.substring(3, content.length() - 3); + } + log.info("content is {}", content); + if (null != content && !content.isEmpty()) { + + var node = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readTree(content); + + if (node.get("endpoints_like") == null || !node.get("endpoints_like").isArray()) { + log.info("No endpoints_like found in response"); + continue; + } + var arrayNode = (ArrayNode) node.get("endpoints_like"); + for (JsonNode localNode : arrayNode) { + endpointsLikeList.add(localNode.asText()); + } + + } + } + + if (endpointsLikeList.size() > 0) { + + ObjectNode endpointsLike = JsonUtil.MAPPER.createObjectNode(); + endpointsLike.put("context", originalContext.orElseThrow().toString()); + endpointsLike.put("endpoints_like", endpointsLikeList); + context.setExecutionArgs(endpointsLike); + var endpoints = getEndpointsLike(execution, context); + log.info("Endpoints like {}", endpoints); + + context.addToMemory("endpoints", endpoints); + } return createdContext; } + @Verb( + name = "summarize_agent_status", returnType = AgentExecutionContextDTO.class, description = + "Summarizes agent status. Used when user asks for agent status.", + requiresTokenManagement = true + ) + public JsonNode getAgentExecutionStatus(AgentExecution execution, AgentExecutionContextDTO context) + throws ZtatException, JsonProcessingException { + var status = agentExecutionService.getExecutionContextDTO(execution.getExecutionId()); + + + var lastTen = ListUtils.getLastNElements(status.getMessages(),10); + var messages = new ArrayList(); + + messages.add(Message.builder().role("system").content("All of the next messages are history between the " + + "system, assistant, and user" + + ". Your job is to" + + " " + + "summarize them. return { \"summary\" : \"summary text\" }").build()); + messages.addAll(status.getMessages()); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + var resp = llmService.askQuestion(execution, chatRequest); + context.addMessages(messages); + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); + //log.info("Response is {}", resp); + for (Response.Choice choice : response.getChoices()) { + var content = choice.getMessage().getContent(); + if (content.startsWith("```json")) { + content = content.substring(7, content.length() - 3); + } else if (content.startsWith("```")) { + content = content.substring(3, content.length() - 3); + } + log.info("content is {}", content); + if (null != content && !content.isEmpty()) { + JsonNode node = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readTree(content); + log.info("Node is {}", node); + if (node.get("summary") != null) { + ArrayNode plan = (ArrayNode) node.get("summary"); + log.info("summary is {}", plan); + return plan; + } + } + } + + return JsonUtil.MAPPER.createObjectNode(); + } + @Verb(name = "create_agent", returnType = AgentExecutionContextDTO.class, description = "Creates an agent who has the " + "context. a previously defined contextId is required. previously defined endpoints can be used to build a " + - "trust policy. must call create_agent_context before this verb.", - exampleJson = "{ \"agentName\": \"agentName\" }", + "trust policy. must call create_agent_context before this verb. agent type is chat, chat-autonomous, or " + + "autonomous. chat is chat only, chat-autonomous is chat and autonomous. determine based on workload.", + exampleJson = "{ \"agentName\": \"agentName\", \"agentType\": \"agentType\" }", requiresTokenManagement = true ) public ObjectNode createAgent(AgentExecution execution, AgentExecutionContextDTO context) throws ZtatException, JsonProcessingException { @@ -609,6 +734,7 @@ public ObjectNode createAgent(AgentExecution execution, AgentExecutionContextDTO var contextId=context.getSafeLabel("created_context", "contextId"); var agentName = context.getSafeLabel("agentName"); + var agentType = context.getSafeLabel("agentType"); Optional optEndpoints = context.getExecutionArgumentScoped("endpoints", ObjectNode.class); var policyId = ""; log.info("Context ID is {}, agentName is {}", contextId, agentName); @@ -650,6 +776,7 @@ public ObjectNode createAgent(AgentExecution execution, AgentExecutionContextDTO var agentBuilder = AgentRegistrationDTO.builder() .agentContextId(contextId) .clientId(UUID.randomUUID().toString()) + .agentType(agentType) .agentName(agentName); if (!policyId.isEmpty()){ log.info("Using policyId {}", policyId); @@ -698,9 +825,30 @@ public ObjectNode getEndpointsLike(AgentExecution execution, var queryInput = executionContextDTO.getExecutionArgs(); log.info("Querying for endpoints like: {}", queryInput); - ObjectNode contextNode = JsonUtil.MAPPER.createObjectNode(); + var parsedQuery = queryInput.get("endpoints_like"); + if (null == parsedQuery) { + throw new IllegalArgumentException("Missing 'endpoints_like' argument"); + } + ObjectNode contextNode = JsonUtil.MAPPER.createObjectNode(); + ArrayNode endpoints = JsonUtil.MAPPER.createArrayNode(); + for(JsonNode node : parsedQuery) { + if (!node.isTextual()) { + throw new IllegalArgumentException("All items in 'endpoints_like' must be strings"); + } + var endpointList = endpointSearcher.getEndpointsLike(execution, node.asText()); + endpointList.forEach(endpoint -> { + ObjectNode endpointNode = JsonUtil.MAPPER.createObjectNode(); + endpointNode.put("name", node.asText()); + endpointNode.put("method", endpoint.getHttpMethod()); + endpointNode.put("endpoint", endpoint.getPath()); + endpoints.add(endpointNode); + }); + } + contextNode.put("endpoints", endpoints); + + /* var endpoints = verbRegistry.getEndpoints(); ArrayNode endpointArray = JsonUtil.MAPPER.createArrayNode(); for (var verb : endpoints) { @@ -734,7 +882,8 @@ public ObjectNode getEndpointsLike(AgentExecution execution, endpoint.put("endpoint", path); endpointArray.add(endpoint); } - + */ + /* var listedEndpoints = Message.builder().role("system").content("These are a list of available endpoints, " + "description," + " their " + @@ -784,6 +933,7 @@ public ObjectNode getEndpointsLike(AgentExecution execution, } } } + */ return contextNode; } diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java index bfe1fe4e..8a88488f 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/ChatVerbs.java @@ -12,7 +12,7 @@ import io.sentrius.agent.analysis.agents.agents.AgentVerb; import io.sentrius.agent.analysis.agents.agents.PromptBuilder; import io.sentrius.agent.analysis.agents.agents.VerbRegistry; -import io.sentrius.agent.analysis.model.TerminalResponse; +import io.sentrius.agent.analysis.model.LLMResponse; import io.sentrius.agent.analysis.model.WebSocky; import io.sentrius.sso.core.dto.agents.AgentExecution; import io.sentrius.sso.core.dto.agents.AgentExecutionContextDTO; @@ -67,13 +67,13 @@ protected ChatVerbs(@Value("${agent.ai.config}") String agentConfigFile, @Verb(name = "interpret_user_request", returnType = ArrayNode.class, description = "Queries the LLM using the " + "user input.", isAiCallable = false, requiresTokenManagement = true) - public TerminalResponse interpretUserData( + public LLMResponse interpretUserData( AgentExecution execution, AgentExecutionContextDTO executionContext, @NonNull WebSocky socketConnection, @NonNull Message userMessage) throws ZtatException, IOException { - var lastMessage = socketConnection.getMessages().stream().reduce((prev, next) -> next).orElse(null); - if (socketConnection.getMessages().isEmpty()) { + var lastMessage = socketConnection.getCommunicationResponses().stream().reduce((prev, next) -> next).orElse(null); + if (socketConnection.getCommunicationResponses().isEmpty()) { InputStream terminalHelperStream = getClass().getClassLoader().getResourceAsStream("terminal-helper.json"); if (terminalHelperStream == null) { @@ -105,7 +105,7 @@ public TerminalResponse interpretUserData( "sessions, using " + "terminal output if needed " + "for clarity of the next LLM request and for the user. Ensure your all future responses meets this " + - "json format (TerminalResponse format): " + terminalResponse).build()); + "json format (LLMResponse format): " + terminalResponse).build()); messages.add(Message.builder().role("user").content(userMessage.getContent()).build()); LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); var resp = llmService.askQuestion(execution, chatRequest); @@ -124,12 +124,12 @@ public TerminalResponse interpretUserData( try { var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( content, - TerminalResponse.class + LLMResponse.class ); return newResponse; }catch (JsonParseException e) { log.error("Failed to parse terminal response: {}", e.getMessage()); - return TerminalResponse.builder().responseForUser(content).build(); + return LLMResponse.builder().responseForUser(content).build(); } } } @@ -173,12 +173,12 @@ public TerminalResponse interpretUserData( try { var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( content, - TerminalResponse.class + LLMResponse.class ); return newResponse; }catch (JsonParseException e) { log.error("Failed to parse terminal response: {}", e.getMessage()); - return TerminalResponse.builder().responseForUser(content).terminalSummaryForLLM(lastMessage.getTerminalSummaryForLLM()).build(); + return LLMResponse.builder().responseForUser(content).summaryForLLM(lastMessage.getSummaryForLLM()).build(); } } @@ -188,6 +188,72 @@ public TerminalResponse interpretUserData( return null; } + public LLMResponse promptAgent( + AgentExecution execution, AgentExecutionContextDTO executionContext, + String prompt) throws ZtatException, + IOException { + + + InputStream terminalHelperStream = getClass().getClassLoader().getResourceAsStream("terminal-helper.json"); + if (terminalHelperStream == null) { + throw new RuntimeException("assessor-config.yaml not found on classpath"); + + } + + String terminalResponse = new String(terminalHelperStream.readAllBytes()); + + AgentConfig config = getAgentConfig(execution); + log.info("Agent config loaded: {}", config); + + List messages = new ArrayList<>(); + var context = Message.builder().role("system").content(prompt).build(); + messages.add(context); + + messages.add(Message.builder().role("system").content("You have executed verbs for the previous user " + + "messages. Please generate a user response that summarizes the last message.").build()); + int size = getMessageSize(context); + + var history = getContextWindow(executionContext.getMessages(), 1024*96 - (size)); + messages.addAll(history); + messages.add(Message.builder().role("system").content("Please ensure your nextOperation abides by the " + + "following json format. Don't leave next operation empty, restart the session if you need, and you " + + "are not interacting with the user, so don't request info from the user, but instead execute the " + + "necessary endpoints" + + "." + + ". " + + ". Please summarize messaging, using " + + "for clarity of the next LLM request and for the user. Ensure your all future responses meets this " + + "json format (LLMResponse format): " + terminalResponse).build()); + LLMRequest chatRequest = LLMRequest.builder().model("gpt-4o-mini").messages(messages).build(); + var resp = llmService.askQuestion(execution, chatRequest); + executionContext.addMessages( messages ); + Response response = JsonUtil.MAPPER.readValue(resp, Response.class); + log.info("Response is {}", resp); + for (Response.Choice choice : response.getChoices()) { + var content = choice.getMessage().getContent(); + if (content.startsWith("```json")) { + content = content.substring(7, content.length() - 3); + } else if (content.startsWith("```")) { + content = content.substring(3, content.length() - 3); + } + log.info("+ {}", content); + if (null != content && !content.isEmpty()) { + try { + var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( + content, + LLMResponse.class + ); + return newResponse; + }catch (JsonParseException e) { + log.error("Failed to parse terminal response: {}", e.getMessage()); + return LLMResponse.builder().responseForUser(content).build(); + } + } + } + return LLMResponse.builder().build(); + } + + public List getContextWindow(List allMessages, int maxContextSize) { List systemMessages = new ArrayList<>(); List window = new ArrayList<>(); @@ -239,16 +305,14 @@ private int getMessageSize(Message msg) { return size; } - public TerminalResponse interpret_plan_response( - AgentExecution execution, AgentExecutionContextDTO executionContext, @NonNull WebSocky socketConnection, + public LLMResponse interpret_plan_response( + AgentExecution execution, AgentExecutionContextDTO executionContext, AgentVerb agentVerb, String planExecutionOutput) throws ZtatException, IOException { log.info("interpret_plan_response {}", planExecutionOutput); - var lastMessage = socketConnection.getMessages().stream().reduce((prev, next) -> next).orElse(null); - InputStream terminalHelperStream = getClass().getClassLoader().getResourceAsStream("terminal-helper.json"); if (terminalHelperStream == null) { throw new RuntimeException("assessor-config.yaml not found on classpath"); @@ -274,7 +338,7 @@ public TerminalResponse interpret_plan_response( messages.add(Message.builder().role("system").content("You have executed verbs for the previous user " + "messages. Please generate a user response that summarizes the last message. Keep all responses in " + - "TerminalResponse format" + + "LLMResponse format" + ".").build()); } else { executionContext.addMessages( messages ); @@ -313,11 +377,11 @@ public TerminalResponse interpret_plan_response( try { var newResponse = JsonUtil.MAPPER.enable(JsonParser.Feature.ALLOW_COMMENTS).readValue( content, - TerminalResponse.class + LLMResponse.class ); return newResponse; } catch (Exception e){ - return TerminalResponse.builder().responseForUser(content).build(); + return LLMResponse.builder().responseForUser(content).build(); } } diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java index 0f324957..c5c437a0 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/agents/verbs/VerbBase.java @@ -34,7 +34,7 @@ protected VerbBase(@Value("${agent.ai.config}") String agentConfigFile, this.agentDatabaseContext = agentDatabaseContext; } - protected AgentConfig getAgentConfig(AgentExecution execution) throws IOException, ZtatException { + public AgentConfig getAgentConfig(AgentExecution execution) throws IOException, ZtatException { AgentConfig config = null; if (agentDatabaseContext != null && !agentDatabaseContext.equals("none")) { AgentContextDTO agentContext = agentClientService.getAgentContext(execution, diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java index f99e0f45..870bbacc 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/api/websocket/ChatWSHandler.java @@ -16,15 +16,16 @@ import io.sentrius.agent.analysis.agents.verbs.ChatVerbs; import io.sentrius.agent.analysis.agents.verbs.TerminalVerbs; import io.sentrius.agent.analysis.api.UserCommunicationService; -import io.sentrius.agent.analysis.model.TerminalResponse; +import io.sentrius.agent.analysis.model.LLMResponse; import io.sentrius.sso.core.exceptions.ZtatException; import io.sentrius.sso.core.services.agents.AgentClientService; +import io.sentrius.sso.core.services.agents.AgentExecutionService; import io.sentrius.sso.core.services.agents.ZeroTrustClientService; import io.sentrius.sso.genai.Message; import io.sentrius.sso.protobuf.Session; import io.sentrius.sso.provenance.ProvenanceEvent; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.stereotype.Component; import org.springframework.web.socket.CloseStatus; @@ -34,6 +35,7 @@ @Slf4j @Component +@RequiredArgsConstructor @ConditionalOnProperty(name = "agents.ai.chat.agent.enabled", havingValue = "true", matchIfMissing = false) public class ChatWSHandler extends TextWebSocketHandler { @@ -42,6 +44,7 @@ public class ChatWSHandler extends TextWebSocketHandler { final TerminalVerbs terminalVerbs; final AgentVerbs agentVerbs; final ChatVerbs chatVerbs; + final AgentExecutionService agentExecutionService; // Store active sessions, using session ID or a custom identifier @@ -49,21 +52,6 @@ public class ChatWSHandler extends TextWebSocketHandler { private final AgentClientService agentClientService; private final VerbRegistry verbRegistry; - @Autowired - public ChatWSHandler(UserCommunicationService userCommunicationService, ZeroTrustClientService zeroTrustClientService, - TerminalVerbs terminalVerbs, AgentVerbs agentVerbs, ChatVerbs chatVerbs, ChatAgent chatAgent, - AgentClientService agentClientService, - VerbRegistry verbRegistry - ) { - this.userCommunicationService = userCommunicationService; - this.zeroTrustClientService = zeroTrustClientService; - this.terminalVerbs = terminalVerbs; - this.agentVerbs = agentVerbs; - this.chatVerbs = chatVerbs; - this.chatAgent = chatAgent; - this.agentClientService = agentClientService; - this.verbRegistry = verbRegistry; - } @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { @@ -199,14 +187,14 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message) base64Message )); - websocky.get().getMessages().add(response); + websocky.get().getCommunicationResponses().add(response); if (response.getNextOperation() != null && !response.getNextOperation().isEmpty() && verbRegistry.isVerbRegistered(response.getNextOperation())) { try { - TerminalResponse nextResponse = null; + LLMResponse nextResponse = null; var lastVerbResponse = websocketCommunication.getVerbResponses().stream() @@ -231,12 +219,11 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message) nextResponse = chatVerbs.interpret_plan_response( chatAgent.getAgentExecution(), websocketCommunication.getAgentExecutionContextDTO(), - websocketCommunication, verbRegistry.getVerbs().get(response.getNextOperation()), planResponse ); - websocky.get().getMessages().add(nextResponse); + websocky.get().getCommunicationResponses().add(nextResponse); websocketCommunication.getVerbResponses().add(executionResponse); @@ -259,6 +246,20 @@ protected void handleTextMessage(WebSocketSession session, TextMessage message) log.info("Next getArguments: {}", nextResponse.getArguments()); lastVerbResponse = executionResponse; response = nextResponse; + + var memory = websocketCommunication.getAgentExecutionContextDTO().flushPersistentMemory(); + if (memory != null) { + for(var memoryEntry : memory.entrySet()){ + agentClientService.storeMemory(chatAgent.getAgentExecution(), + websocketCommunication.getAgentExecutionContextDTO().getAgentContext().getName(), + io.sentrius.sso.core.dto.agents.AgentMemoryDTO.builder() + .agentName(websocketCommunication.getAgentExecutionContextDTO().getAgentContext().getName()) + .memoryKey(memoryEntry.getKey()) + .memoryValue(memoryEntry.getValue().toString()) + .build()); + } + } + }while (nextResponse.getNextOperation() != null && !nextResponse.getNextOperation().isEmpty()); }catch (Exception e){ e.printStackTrace(); diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/model/LLMResponse.java similarity index 82% rename from ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java rename to ai-agent/src/main/java/io/sentrius/agent/analysis/model/LLMResponse.java index 2f62e81a..ccc5a84c 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/model/TerminalResponse.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/model/LLMResponse.java @@ -1,11 +1,7 @@ package io.sentrius.agent.analysis.model; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; -import io.sentrius.sso.core.dto.HostSystemDTO; -import io.sentrius.sso.genai.Message; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -13,7 +9,6 @@ import lombok.NoArgsConstructor; import lombok.Setter; import lombok.extern.slf4j.Slf4j; -import org.springframework.web.socket.WebSocketSession; @Data @Builder @@ -22,10 +17,10 @@ @NoArgsConstructor @AllArgsConstructor @Slf4j -public class TerminalResponse { +public class LLMResponse { String previousOperation; String nextOperation; - String terminalSummaryForLLM; + String summaryForLLM; String responseForUser; @Builder.Default public Map arguments = new HashMap<>(); diff --git a/ai-agent/src/main/java/io/sentrius/agent/analysis/model/WebSocky.java b/ai-agent/src/main/java/io/sentrius/agent/analysis/model/WebSocky.java index 4452db94..50bafbbb 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/analysis/model/WebSocky.java +++ b/ai-agent/src/main/java/io/sentrius/agent/analysis/model/WebSocky.java @@ -5,7 +5,6 @@ import io.sentrius.sso.core.dto.HostSystemDTO; import io.sentrius.sso.core.dto.agents.AgentExecutionContextDTO; import io.sentrius.sso.core.model.verbs.VerbResponse; -import io.sentrius.sso.genai.Message; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -26,7 +25,7 @@ public class WebSocky { Long uniqueIdentifier; WebSocketSession webSocketSession; @Builder.Default - List messages = new ArrayList<>(); + List communicationResponses = new ArrayList<>(); @Builder.Default List verbResponses = new ArrayList<>(); diff --git a/ai-agent/src/main/java/io/sentrius/agent/config/AgentConfigOptions.java b/ai-agent/src/main/java/io/sentrius/agent/config/AgentConfigOptions.java index 969e1742..befc7eee 100644 --- a/ai-agent/src/main/java/io/sentrius/agent/config/AgentConfigOptions.java +++ b/ai-agent/src/main/java/io/sentrius/agent/config/AgentConfigOptions.java @@ -1,15 +1,21 @@ package io.sentrius.agent.config; import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Getter; +import lombok.NoArgsConstructor; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.springframework.boot.context.properties.ConfigurationProperties; +@Builder @Slf4j @ConfigurationProperties(prefix = "agent") @Getter @Setter +@AllArgsConstructor +@NoArgsConstructor public class AgentConfigOptions { diff --git a/ai-agent/src/main/java/io/sentrius/agent/services/EndpointRegistry.java b/ai-agent/src/main/java/io/sentrius/agent/services/EndpointRegistry.java new file mode 100644 index 00000000..38f5f5fc --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/services/EndpointRegistry.java @@ -0,0 +1,102 @@ +package io.sentrius.agent.services; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.sentrius.sso.core.dto.agents.AgentExecution; +import io.sentrius.sso.core.dto.agents.AgentMemoryDTO; +import io.sentrius.sso.core.dto.agents.MemoryQueryDTO; +import io.sentrius.sso.core.dto.capabilities.EndpointDescriptor; +import io.sentrius.sso.core.dto.ztat.TokenDTO; +import io.sentrius.sso.core.embeddings.EmbeddingService; +import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.core.services.agents.AgentClientService; +import jakarta.annotation.PostConstruct; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +@RequiredArgsConstructor +@Component +@Slf4j +public class EndpointRegistry { + public static final String MEMORY_NAME = "all-endpoints"; + Map embeddingMap = new HashMap<>(); + Map descriptorMap = new HashMap<>(); + + private final AgentClientService agentClientService; + private final EmbeddingService embeddingService; + + public void loadEndpoints(AgentExecution dto) throws ZtatException, JsonProcessingException { + List endpoints = agentClientService.getAvailableEndpoints(dto); // however you get them + + for (EndpointDescriptor ed : endpoints) { + String key = buildKey(ed); + String json = EndpointDescriptor.toEmbeddableJson(ed); + float[] embedding = null; + + MemoryQueryDTO query = MemoryQueryDTO.builder() + .agentId(MEMORY_NAME) + .memoryKey(key) + .searchTerm(key) + .build(); + List existing = agentClientService.retrieveMemories(dto, MEMORY_NAME, query); + + if (existing != null && !existing.isEmpty() && existing.get(0).isHasEmbedding()) { + embedding = existing.get(0).getEmbedding(); + log.info("Reusing existing embedding for {}", key); + } else { + embedding = embeddingService.embed(dto, json); + + AgentMemoryDTO memory = AgentMemoryDTO.builder() + .memoryKey(key) + .memoryValue(json) + .memoryType("endpoint") + .agentId(MEMORY_NAME) + .classification("public") + .accessLevel("read") + .creatorUserId("system") + .hasEmbedding(true) + .embedding(embedding) + .build(); + log.info("Storing embedding memory for {}", key); + agentClientService.storeMemory(dto, MEMORY_NAME, memory); + } + log.info("Key={} | Embedding hash={} | First5={}", + key, + System.identityHashCode(embedding), + Arrays.toString(Arrays.copyOfRange(embedding, 0, 5))); + embeddingMap.put(key, embedding); + descriptorMap.put(key, ed); + + } + } + + public List getAll() { + return new ArrayList<>(descriptorMap.values()); + } + + public Optional getDescriptor(String key) { + return Optional.ofNullable(descriptorMap.get(key)); + } + + public Optional getEmbedding(String key) { + return Optional.ofNullable(embeddingMap.get(key)); + } + + public Optional getEmbedding(EndpointDescriptor ed) { + return Optional.ofNullable(embeddingMap.get(buildKey(ed))); + } + + private String buildKey(EndpointDescriptor ed) { + return ed.getHttpMethod() + "@" + ed.getPath(); + } + + public List getAllEndpoints() { + return new ArrayList<>(descriptorMap.values()); + } +} diff --git a/ai-agent/src/main/java/io/sentrius/agent/services/EndpointSearcher.java b/ai-agent/src/main/java/io/sentrius/agent/services/EndpointSearcher.java new file mode 100644 index 00000000..576ff19b --- /dev/null +++ b/ai-agent/src/main/java/io/sentrius/agent/services/EndpointSearcher.java @@ -0,0 +1,62 @@ +package io.sentrius.agent.services; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.sentrius.sso.core.dto.capabilities.EndpointDescriptor; +import io.sentrius.sso.core.dto.ztat.TokenDTO; +import io.sentrius.sso.core.embeddings.EmbeddingService; +import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.core.services.endpoints.CosineSimilarity; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +@Slf4j +@Service +public class EndpointSearcher { + + private final EmbeddingService embeddingService; + private final EndpointRegistry endpointRegistry; + + public EndpointSearcher(EmbeddingService embeddingService, + EndpointRegistry endpointRegistry + ) { + + this.embeddingService = embeddingService; + this.endpointRegistry = endpointRegistry; + } + + public List getEndpointsLike(TokenDTO dto, String query) + throws ZtatException, JsonProcessingException { + float[] queryVector = embeddingService.embed(dto, query); + + List endpoints = endpointRegistry.getAllEndpoints(); + return endpoints.stream() + .map(ed -> { + var embed = endpointRegistry.getEmbedding(ed); + if (embed.isEmpty()) { + log.warn("No embedding found for endpoint: {}", ed.getName()); + return Map.entry(ed, 0.0f); + } + + var arr = embed.get(); + log.info("Scoring {} | Query first5={} | Endpoint first5={}", + ed.getName(), + Arrays.toString(Arrays.copyOfRange(queryVector, 0, 5)), + Arrays.toString(Arrays.copyOfRange(arr, 0, 5))); + var score = CosineSimilarity.score(queryVector, + embed.orElseThrow(() -> new RuntimeException("Embedding not found for " + + "endpoint: " + ed.getName()))); + log.info("Calculating similarity for endpoint: {} and {} {} ", ed.getName(), embed.get().length, + score); + return Map.entry(ed, score); + } + ) + .sorted((a, b) -> Float.compare(b.getValue(), a.getValue())) + .filter(entry -> entry.getValue() > 0.75) // adjust threshold as needed + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + } +} diff --git a/ai-agent/src/main/resources/chat-helper.json b/ai-agent/src/main/resources/chat-helper.json index bffb7174..4fc90834 100644 --- a/ai-agent/src/main/resources/chat-helper.json +++ b/ai-agent/src/main/resources/chat-helper.json @@ -5,6 +5,6 @@ "argumentname": "", "argumentname2": "" }, - "terminalSummaryForLLM": "", + "summaryForLLM": "", "responseForUser": "" } \ No newline at end of file diff --git a/ai-agent/src/main/resources/chat-helper.yaml b/ai-agent/src/main/resources/chat-helper.yaml index ae9be177..432339b1 100644 --- a/ai-agent/src/main/resources/chat-helper.yaml +++ b/ai-agent/src/main/resources/chat-helper.yaml @@ -5,6 +5,6 @@ context: | { "previousOperation": "", "nextOperation": "", - "terminalSummaryForLLM": "", + "summaryForLLM": "", "responseForUser": "" } \ No newline at end of file diff --git a/ai-agent/src/main/resources/terminal-helper.json b/ai-agent/src/main/resources/terminal-helper.json index 02be454b..e9461855 100644 --- a/ai-agent/src/main/resources/terminal-helper.json +++ b/ai-agent/src/main/resources/terminal-helper.json @@ -5,6 +5,6 @@ "argumentname": "", "argumentname2": "" }, - "terminalSummaryForLLM": "", + "summaryForLLM": "", "responseForUser": "" } \ No newline at end of file diff --git a/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java b/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java index bef57d79..58201627 100644 --- a/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java +++ b/api/src/main/java/io/sentrius/sso/config/SecurityConfig.java @@ -10,6 +10,7 @@ import io.sentrius.sso.core.security.CustomAuthenticationSuccessHandler; import io.sentrius.sso.core.services.CustomUserDetailsService; import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.security.KeycloakService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; @@ -34,6 +35,7 @@ public class SecurityConfig { private final CustomUserDetailsService userDetailsService; private final CustomAuthenticationSuccessHandler successHandler; private final KeycloakAuthSuccessHandler keycloakAuthSuccessHandler; + private final KeycloakService keycloakService; final UserService userService; @Value("${https.required:false}") // Default is false @@ -88,6 +90,7 @@ public JwtAuthenticationConverter jwtAuthenticationConverterForKeycloak() { User user = userService.getUserByUsername(username); if (user == null) { + var type = userService.getUserType( UserType.createUnknownUser()); if (type.isEmpty()) { diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/IntegrationApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/IntegrationApiController.java index f53a0ba2..6d87da26 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/IntegrationApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/IntegrationApiController.java @@ -9,6 +9,7 @@ import io.sentrius.sso.core.controllers.BaseController; import io.sentrius.sso.core.model.security.IntegrationSecurityToken; import io.sentrius.sso.core.model.users.UserConfig; +import io.sentrius.sso.core.model.verbs.Endpoint; import io.sentrius.sso.core.services.ErrorOutputService; import io.sentrius.sso.core.integrations.external.ExternalIntegrationDTO; import io.sentrius.sso.core.services.UserService; @@ -54,6 +55,7 @@ protected IntegrationApiController( } @PostMapping("/github/add") + @Endpoint(description = "Adding a github integration so github can be used as an external data provider") public ResponseEntity addGitHubIntegration(HttpServletRequest request, HttpServletResponse response, ExternalIntegrationDTO integrationDTO) @@ -73,6 +75,7 @@ public ResponseEntity addGitHubIntegration(HttpServletRe } @PostMapping("/jira/add") + @Endpoint(description = "Adding a jira integration so jira can be used as an external data provider") public ResponseEntity addJiraIntegration(HttpServletRequest request, HttpServletResponse response, ExternalIntegrationDTO integrationDTO) throws JsonProcessingException, GeneralSecurityException { @@ -92,6 +95,7 @@ public ResponseEntity addJiraIntegration(HttpServletRequ } @PostMapping("/openai/add") + @Endpoint(description = "Adding an OpenAI integration so OpenAI can be used as an external data provider") public ResponseEntity addOpenaiIntegration(HttpServletRequest request, HttpServletResponse response, @RequestBody ExternalIntegrationDTO integrationDTO) @@ -111,6 +115,7 @@ public ResponseEntity addOpenaiIntegration(HttpServletRe } @PostMapping("/jira/delete") + @Endpoint(description = "Deleting a jira integration so jira can no longer be used as an external data provider") public ResponseEntity deleteJiraIntegration(HttpServletRequest request, HttpServletResponse response, @RequestParam("id") String id) @@ -122,6 +127,7 @@ public ResponseEntity deleteJiraIntegration(HttpServletRequest request, } @PostMapping("/delete") + @Endpoint(description = "Deleting an integration so it can no longer be used as an external data provider") public ResponseEntity deleteIntegration(HttpServletRequest request, HttpServletResponse response, @RequestParam("integrationId") String id) { diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentApiController.java similarity index 99% rename from api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java rename to api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentApiController.java index 07c5d417..c04a9a4c 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/AgentApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentApiController.java @@ -1,4 +1,4 @@ -package io.sentrius.sso.controllers.api; +package io.sentrius.sso.controllers.api.agents; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java b/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentBootstrapController.java similarity index 97% rename from api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java rename to api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentBootstrapController.java index c10c0c65..2b218024 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/AgentBootstrapController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentBootstrapController.java @@ -1,11 +1,9 @@ -package io.sentrius.sso.controllers.api; +package io.sentrius.sso.controllers.api.agents; import java.io.IOException; import java.io.InputStream; import java.security.GeneralSecurityException; -import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Optional; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; @@ -39,13 +37,10 @@ import jakarta.servlet.http.HttpServletResponse; import jakarta.transaction.Transactional; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Value; import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentMemoryController.java b/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentMemoryController.java new file mode 100644 index 00000000..3f65cb74 --- /dev/null +++ b/api/src/main/java/io/sentrius/sso/controllers/api/agents/AgentMemoryController.java @@ -0,0 +1,513 @@ +package io.sentrius.sso.controllers.api.agents; + +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.controllers.BaseController; +import io.sentrius.sso.core.dto.agents.AgentMemoryDTO; +import io.sentrius.sso.core.dto.agents.MemoryQueryDTO; +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.services.ErrorOutputService; +import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.agents.PersistentAgentMemoryStore; +import io.sentrius.sso.core.services.agents.VectorAgentMemoryStore; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.extern.slf4j.Slf4j; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Sort; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.security.core.Authentication; +import org.springframework.web.bind.annotation.*; + +import jakarta.validation.Valid; +import java.security.Principal; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +@Slf4j +@RestController +@RequestMapping("/api/v1/agents/memory") +public class AgentMemoryController extends BaseController { + + private final PersistentAgentMemoryStore memoryStore; + private final VectorAgentMemoryStore vectorMemoryStore; + + public AgentMemoryController(PersistentAgentMemoryStore memoryStore, VectorAgentMemoryStore vectorMemoryStore, UserService userService, SystemOptions systemOptions, ErrorOutputService errorOutputService) { + super(userService, systemOptions, errorOutputService); + this.memoryStore = memoryStore; + this.vectorMemoryStore = vectorMemoryStore; + } + + /** + * Store agent memory + */ + @PostMapping("/store") + public ResponseEntity storeMemory( + @RequestParam(name = "agentId") String agentId, + @RequestBody @Valid AgentMemoryDTO memoryDTO, + @RequestParam(defaultValue = "false") boolean generateEmbedding, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Storing memory for agent: {}, key: {}, embedding: {}", + agentId, memoryDTO.getMemoryKey(), generateEmbedding || memoryDTO.isHasEmbedding()); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + AgentMemory memory; + + if (generateEmbedding) { + // Use vector store for embedding generation + memory = vectorMemoryStore.storeMemoryWithEmbedding( + agentId, + memoryDTO.getMemoryKey(), + memoryDTO.getMemoryValue(), + memoryDTO.getClassification(), + memoryDTO.getMarkings(), + userId + ); + } else if (memoryDTO.isHasEmbedding() && memoryDTO.getEmbedding() != null) { + // Store with provided embedding + memory = vectorMemoryStore.storeMemoryWithProvidedEmbedding( + agentId, + memoryDTO.getMemoryKey(), + memoryDTO.getMemoryValue(), + memoryDTO.getClassification(), + memoryDTO.getMarkings(), + memoryDTO.getEmbedding(), + userId + ); + } else { + // Use traditional storage + memory = memoryStore.storeMemory( + agentId, + memoryDTO.getMemoryKey(), + memoryDTO.getMemoryValue(), + memoryDTO.getClassification(), + memoryDTO.getMarkings(), + userId + ); + } + + AgentMemoryDTO responseDTO = convertToDTO(memory); + return ResponseEntity.ok(responseDTO); + + } catch (Exception e) { + log.error("Error storing memory for agent: {}, key: {}", agentId, memoryDTO.getMemoryKey(), e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Retrieve agent memory + */ + @GetMapping("/{agentId}/{memoryKey}") + public ResponseEntity retrieveMemory( + @PathVariable String agentId, + @PathVariable String memoryKey, + HttpServletRequest request, HttpServletResponse response) { + + log.debug("Retrieving memory for agent: {}, key: {}", agentId, memoryKey); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + Optional memoryOpt = memoryStore.retrieveMemory(agentId, memoryKey, userId); + + if (memoryOpt.isPresent()) { + AgentMemoryDTO responseDTO = convertToDTO(memoryOpt.get()); + return ResponseEntity.ok(responseDTO); + } else { + return ResponseEntity.notFound().build(); + } + + } catch (Exception e) { + log.error("Error retrieving memory for agent: {}, key: {}", agentId, memoryKey, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Query memories with filters + */ + @PostMapping("/query") + public ResponseEntity> queryMemories( + @RequestBody @Valid MemoryQueryDTO queryDTO, + HttpServletRequest request, HttpServletResponse response) { + + log.debug("Querying memories with filters: {}", queryDTO); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + PageRequest pageRequest = PageRequest.of( + queryDTO.getPage(), + queryDTO.getSize(), + Sort.by(queryDTO.getSortDirection(), queryDTO.getSortBy()) + ); + + Page memories = memoryStore.queryMemories( + queryDTO.getAgentId(), + queryDTO.getClassification(), + queryDTO.getMarkings(), + userId, + pageRequest + ); + + Page responsePage = memories.map(this::convertToDTO); + return ResponseEntity.ok(responsePage); + + } catch (Exception e) { + log.error("Error querying memories", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get shareable memories for an agent + */ + @GetMapping("/{agentId}/shareable") + public ResponseEntity> getShareableMemories( + @PathVariable String agentId, + HttpServletRequest request, HttpServletResponse response) { + + log.debug("Getting shareable memories for agent: {}", agentId); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + List shareableMemories = memoryStore.findShareableMemories(agentId, userId); + List responseDTOs = shareableMemories.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error getting shareable memories for agent: {}", agentId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Share memory with specific agents + */ + @PostMapping("/{agentId}/{memoryKey}/share") + public ResponseEntity> shareMemory( + @PathVariable String agentId, + @PathVariable String memoryKey, + @RequestBody Map shareRequest, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Sharing memory: agent={}, key={}", agentId, memoryKey); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + @SuppressWarnings("unchecked") + List targetAgentsList = (List) shareRequest.get("targetAgents"); + String[] targetAgents = targetAgentsList.toArray(new String[0]); + + boolean success = memoryStore.shareMemoryWithAgents(agentId, memoryKey, targetAgents, userId); + + Map userResponse = new HashMap<>(); + userResponse.put("success", success); + userResponse.put("sharedWith", targetAgents); + + return ResponseEntity.ok(userResponse); + + } catch (Exception e) { + log.error("Error sharing memory: agent={}, key={}", agentId, memoryKey, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Search memories by markings + */ + @GetMapping("/search/markings/{marking}") + public ResponseEntity> searchByMarkings( + @PathVariable String marking, + HttpServletRequest request, HttpServletResponse response) { + + log.debug("Searching memories by marking: {}", marking); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + List memories = memoryStore.findMemoriesByMarkings(marking, userId); + List responseDTOs = memories.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error searching memories by marking: {}", marking, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Delete memory + */ + @DeleteMapping("/{agentId}/{memoryKey}") + public ResponseEntity> deleteMemory( + @PathVariable String agentId, + @PathVariable String memoryKey, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Deleting memory: agent={}, key={}", agentId, memoryKey); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + boolean success = memoryStore.deleteMemory(agentId, memoryKey, userId); + + Map userResponse = new HashMap<>(); + userResponse.put("success", success); + userResponse.put("deleted", success); + + return ResponseEntity.ok(userResponse); + + } catch (Exception e) { + log.error("Error deleting memory: agent={}, key={}", agentId, memoryKey, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get memory statistics for an agent + */ + @GetMapping("/{agentId}/statistics") + public ResponseEntity> getMemoryStatistics(@PathVariable String agentId) { + log.debug("Getting memory statistics for agent: {}", agentId); + + try { + Map stats = memoryStore.getMemoryStatistics(agentId); + return ResponseEntity.ok(stats); + + } catch (Exception e) { + log.error("Error getting memory statistics for agent: {}", agentId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Clean up expired memories (admin endpoint) + */ + @PostMapping("/cleanup/expired") + public ResponseEntity> cleanupExpiredMemories(HttpServletRequest request, HttpServletResponse response) { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + log.info("Cleaning up expired memories, requested by: {}", userId); + + try { + memoryStore.cleanupExpiredMemories(); + + Map userResponse = new HashMap<>(); + userResponse.put("success", true); + userResponse.put("message", "Expired memories cleanup completed"); + + return ResponseEntity.ok(userResponse); + + } catch (Exception e) { + log.error("Error cleaning up expired memories", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + // === SEMANTIC SEARCH ENDPOINTS === + + /** + * Find semantically similar memories using vector similarity + */ + @PostMapping("/search/semantic") + public ResponseEntity> semanticSearch( + @RequestBody Map searchRequest, + HttpServletRequest request, HttpServletResponse response) { + + String queryText = (String) searchRequest.get("query"); + Integer limit = (Integer) searchRequest.getOrDefault("limit", 10); + Double threshold = (Double) searchRequest.getOrDefault("threshold", 0.7); + + log.debug("Semantic search query: {}, limit: {}, threshold: {}", queryText, limit, threshold); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + List similarMemories = vectorMemoryStore.findSimilarMemories( + queryText, userId, limit, threshold); + + List responseDTOs = similarMemories.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error in semantic search", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Find semantically similar memories for a specific agent + */ + @PostMapping("/search/semantic/{agentId}") + public ResponseEntity> semanticSearchForAgent( + @PathVariable String agentId, + @RequestBody Map searchRequest, + HttpServletRequest request, HttpServletResponse response) { + + String queryText = (String) searchRequest.get("query"); + Integer limit = (Integer) searchRequest.getOrDefault("limit", 10); + Double threshold = (Double) searchRequest.getOrDefault("threshold", 0.7); + + log.debug("Agent semantic search - agent: {}, query: {}, limit: {}", agentId, queryText, limit); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + List similarMemories = vectorMemoryStore.findSimilarMemoriesForAgent( + queryText, agentId, userId, limit, threshold); + + List responseDTOs = similarMemories.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error in agent semantic search for agent: {}", agentId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Hybrid search combining text and vector similarity + */ + @PostMapping("/search/hybrid") + public ResponseEntity> hybridSearch( + @RequestBody Map searchRequest, + HttpServletRequest request, HttpServletResponse response) { + + String searchTerm = (String) searchRequest.get("searchTerm"); + String markingsFilter = (String) searchRequest.get("markings"); + Integer limit = (Integer) searchRequest.getOrDefault("limit", 10); + Double threshold = (Double) searchRequest.getOrDefault("threshold", 0.7); + + log.debug("Hybrid search - term: {}, markings: {}, limit: {}", searchTerm, markingsFilter, limit); + + try { + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + + List results = vectorMemoryStore.hybridSearch( + searchTerm, markingsFilter, userId, limit, threshold); + + List responseDTOs = results.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error in hybrid search", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Generate embeddings for memories that don't have them (admin endpoint) + */ + @PostMapping("/embeddings/generate") + public ResponseEntity> generateMissingEmbeddings( + @RequestParam(defaultValue = "100") int batchSize, + HttpServletRequest request, HttpServletResponse response) { + + var operatingUser = getOperatingUser(request,response); + String userId = operatingUser.getUserId(); + log.info("Generating missing embeddings, batch size: {}, requested by: {}", + batchSize, userId); + + try { + vectorMemoryStore.generateMissingEmbeddings(batchSize); + + Map userResponse = new HashMap<>(); + userResponse.put("success", true); + userResponse.put("message", "Embedding generation started for batch size: " + batchSize); + + return ResponseEntity.ok(userResponse); + + } catch (Exception e) { + log.error("Error generating embeddings", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get vector store statistics + */ + @GetMapping("/statistics/vector") + public ResponseEntity> getVectorStoreStatistics() { + log.debug("Getting vector store statistics"); + + try { + Map stats = vectorMemoryStore.getVectorStoreStatistics(); + return ResponseEntity.ok(stats); + + } catch (Exception e) { + log.error("Error getting vector store statistics", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Convert AgentMemory entity to DTO + */ + private AgentMemoryDTO convertToDTO(AgentMemory memory) { + AgentMemoryDTO dto = AgentMemoryDTO.builder() + .id(memory.getId()) + .memoryKey(memory.getMemoryKey()) + .memoryValue(memory.getMemoryValue()) + .memoryType(memory.getMemoryType()) + .agentId(memory.getAgentId()) + .agentName(memory.getAgentName()) + .conversationId(memory.getConversationId()) + .classification(memory.getClassification()) + .markings(memory.getMarkingsArray()) + .accessLevel(memory.getAccessLevel()) + .creatorUserId(memory.getCreatorUserId()) + .creatorUserType(memory.getCreatorUserType()) + .createdAt(memory.getCreatedAt()) + .updatedAt(memory.getUpdatedAt()) + .expiresAt(memory.getExpiresAt()) + .sharedWithAgents(memory.getSharedAgentsArray()) + .metadata(memory.getMetadataAsMap()) + .version(memory.getVersion()) + .hasEmbedding(memory.hasEmbedding()) + .build(); + + // Only include embedding if it exists (optional for performance) + if (memory.hasEmbedding()) { + dto.setEmbeddingFromArray(memory.getEmbedding()); + } + + return dto; + } + +} \ No newline at end of file diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/UserApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/users/UserApiController.java similarity index 99% rename from api/src/main/java/io/sentrius/sso/controllers/api/UserApiController.java rename to api/src/main/java/io/sentrius/sso/controllers/api/users/UserApiController.java index 71d88542..5e4d25f7 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/UserApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/users/UserApiController.java @@ -1,4 +1,4 @@ -package io.sentrius.sso.controllers.api; +package io.sentrius.sso.controllers.api.users; import java.lang.reflect.Field; import java.security.GeneralSecurityException; @@ -6,7 +6,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.BooleanNode; @@ -28,7 +27,6 @@ import io.sentrius.sso.core.model.users.UserConfig; import io.sentrius.sso.core.model.users.UserPublicKey; import io.sentrius.sso.core.model.users.UserSettings; -import io.sentrius.sso.core.services.ConfigurationService; import io.sentrius.sso.core.services.ErrorOutputService; import io.sentrius.sso.core.services.HostGroupService; import io.sentrius.sso.core.services.SessionService; @@ -40,14 +38,13 @@ import io.sentrius.sso.core.services.security.CryptoService; import io.sentrius.sso.core.services.security.ZeroTrustAccessTokenService; import io.sentrius.sso.core.services.security.ZeroTrustRequestService; +import io.sentrius.sso.core.services.users.UserAttributeService; import io.sentrius.sso.core.utils.JsonUtil; import io.sentrius.sso.core.utils.MessagingUtil; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Controller; diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/users/UserAttributeController.java b/api/src/main/java/io/sentrius/sso/controllers/api/users/UserAttributeController.java new file mode 100644 index 00000000..a7119b1a --- /dev/null +++ b/api/src/main/java/io/sentrius/sso/controllers/api/users/UserAttributeController.java @@ -0,0 +1,374 @@ +package io.sentrius.sso.controllers.api.users; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.controllers.BaseController; +import io.sentrius.sso.core.dto.users.UserAttributeDTO; +import io.sentrius.sso.core.model.security.enums.ApplicationAccessEnum; +import io.sentrius.sso.core.model.users.UserAttribute; +import io.sentrius.sso.core.services.ErrorOutputService; +import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.users.UserAttributeService; +import io.sentrius.sso.core.utils.AccessUtil; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.validation.Valid; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +@Slf4j +@RestController +@RequestMapping("/api/v1/users/attributes") +public class UserAttributeController extends BaseController { + + private final UserAttributeService userAttributeService; + + public UserAttributeController(UserAttributeService userAttributeService, UserService userService, SystemOptions systemOptions, ErrorOutputService errorOutputService) { + super(userService, systemOptions, errorOutputService); + this.userAttributeService = userAttributeService; + } + + /** + * Get all attributes for the current user + */ + @GetMapping("/me") + public ResponseEntity> getMyAttributes(HttpServletRequest request, HttpServletResponse response) { + + var operatingUser = getOperatingUser(request, response); + + log.debug("Getting attributes for user: {}", operatingUser.getUserId()); + + try { + List attributes = userAttributeService.getUserAttributes(operatingUser.getUserId()); + List attributeDTOs = attributes.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(attributeDTOs); + + } catch (Exception e) { + log.error("Error getting attributes for user: {}", operatingUser.getUserId(), e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get all attributes for a specific user (admin endpoint) + */ + @GetMapping("/{userId}") + public ResponseEntity> getUserAttributes(@PathVariable String userId) { + log.debug("Getting attributes for user: {}", userId); + + try { + List attributes = userAttributeService.getUserAttributes(userId); + List attributeDTOs = attributes.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(attributeDTOs); + + } catch (Exception e) { + log.error("Error getting attributes for user: {}", userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get user attributes as a map + */ + @GetMapping("/{userId}/map") + public ResponseEntity> getUserAttributesAsMap(@PathVariable String userId) { + log.debug("Getting attributes map for user: {}", userId); + + try { + Map attributesMap = userAttributeService.getUserAttributesAsMap(userId); + return ResponseEntity.ok(attributesMap); + + } catch (Exception e) { + log.error("Error getting attributes map for user: {}", userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get a specific user attribute + */ + @GetMapping("/{userId}/{attributeName}") + public ResponseEntity getUserAttribute( + @PathVariable String userId, + @PathVariable String attributeName) { + + log.debug("Getting attribute {} for user: {}", attributeName, userId); + + try { + Optional attributeOpt = userAttributeService.getUserAttribute(userId, attributeName); + + if (attributeOpt.isPresent()) { + UserAttributeDTO attributeDTO = convertToDTO(attributeOpt.get()); + return ResponseEntity.ok(attributeDTO); + } else { + return ResponseEntity.notFound().build(); + } + + } catch (Exception e) { + log.error("Error getting attribute {} for user: {}", attributeName, userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Set a user attribute + */ + @PostMapping("/update") + public ResponseEntity setUserAttribute( + @RequestParam("userId") String userId, + @RequestBody @Valid UserAttributeDTO attributeDTO, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Setting attribute {} for user: {}", attributeDTO.getAttributeName(), userId); + + try { + var operatingUser = getOperatingUser(request, response); + String requestingUserId = operatingUser.getUserId(); + + // Basic authorization - users can only modify their own attributes unless admin + if (!userId.equals(requestingUserId) && !AccessUtil.canAccess(operatingUser, ApplicationAccessEnum.CAN_MANAGE_APPLICATION)) { + log.warn("User {} attempted to modify attributes for user {}", requestingUserId, userId); + return ResponseEntity.status(HttpStatus.FORBIDDEN).build(); + } + + UserAttribute attribute = userAttributeService.setUserAttribute( + userId, + attributeDTO.getAttributeName(), + attributeDTO.getAttributeValue(), + attributeDTO.getAttributeType(), + attributeDTO.getSource() + ); + + UserAttributeDTO responseDTO = convertToDTO(attribute); + return ResponseEntity.ok(responseDTO); + + } catch (IllegalArgumentException e) { + log.warn("Invalid attribute data for user {}: {}", userId, e.getMessage()); + return ResponseEntity.badRequest().build(); + } catch (Exception e) { + log.error("Error setting attribute {} for user: {}", attributeDTO.getAttributeName(), userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Set multiple user attributes at once + */ + @PostMapping("/{userId}/bulk") + public ResponseEntity> setUserAttributes( + @PathVariable String userId, + @RequestBody Map attributes, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Setting {} attributes for user: {}", attributes.size(), userId); + + try { + var operatingUser = getOperatingUser(request, response); + String requestingUserId = operatingUser.getUserId(); + + // Basic authorization - users can only modify their own attributes unless admin + if (!userId.equals(requestingUserId) && !AccessUtil.canAccess(operatingUser, ApplicationAccessEnum.CAN_MANAGE_APPLICATION)) { + log.warn("User {} attempted to modify attributes for user {}", requestingUserId, userId); + return ResponseEntity.status(HttpStatus.FORBIDDEN).build(); + } + + List savedAttributes = userAttributeService.setUserAttributes( + userId, attributes, "SENTRIUS"); + + List responseDTOs = savedAttributes.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error setting bulk attributes for user: {}", userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Remove a user attribute + */ + @DeleteMapping("/{userId}/{attributeName}") + public ResponseEntity> removeUserAttribute( + @PathVariable String userId, + @PathVariable String attributeName, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Removing attribute {} for user: {}", attributeName, userId); + + try { + var operatingUser = getOperatingUser(request, response); + String requestingUserId = operatingUser.getUserId(); + + // Basic authorization - users can only modify their own attributes unless admin + if (!userId.equals(requestingUserId) && !AccessUtil.canAccess(operatingUser, ApplicationAccessEnum.CAN_MANAGE_APPLICATION)) { + log.warn("User {} attempted to remove attributes for user {}", requestingUserId, userId); + return ResponseEntity.status(HttpStatus.FORBIDDEN).build(); + } + + boolean success = userAttributeService.removeUserAttribute(userId, attributeName); + + Map userResponse = new HashMap<>(); + userResponse.put("success", success); + userResponse.put("removed", success); + + return ResponseEntity.ok(userResponse); + + } catch (Exception e) { + log.error("Error removing attribute {} for user: {}", attributeName, userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Sync user attributes from Keycloak + */ + @PostMapping("/{userId}/sync/keycloak") + public ResponseEntity> syncFromKeycloak( + @PathVariable String userId, + HttpServletRequest request, HttpServletResponse response) { + + log.info("Syncing attributes from Keycloak for user: {}", userId); + + try { + var operatingUser = getOperatingUser(request, response); + String requestingUserId = operatingUser.getUserId(); + + // Basic authorization - users can sync their own attributes or admin can sync any + if (!userId.equals(requestingUserId) && !AccessUtil.canAccess(operatingUser, ApplicationAccessEnum.CAN_MANAGE_APPLICATION)) { + log.warn("User {} attempted to sync attributes for user {}", requestingUserId, userId); + return ResponseEntity.status(HttpStatus.FORBIDDEN).build(); + } + + List syncedAttributes = userAttributeService.syncUserAttributesFromKeycloak(userId); + List responseDTOs = syncedAttributes.stream() + .map(this::convertToDTO) + .collect(Collectors.toList()); + + return ResponseEntity.ok(responseDTOs); + + } catch (Exception e) { + log.error("Error syncing attributes from Keycloak for user: {}", userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Check if user has specific attribute value + */ + @GetMapping("/{userId}/check") + public ResponseEntity> checkUserAttribute( + @PathVariable String userId, + @RequestParam String attributeName, + @RequestParam String attributeValue) { + + log.debug("Checking if user {} has attribute {}={}", userId, attributeName, attributeValue); + + try { + boolean hasAttribute = userAttributeService.userHasAttributeValue(userId, attributeName, attributeValue); + + Map response = new HashMap<>(); + response.put("hasAttribute", hasAttribute); + + return ResponseEntity.ok(response); + + } catch (Exception e) { + log.error("Error checking attribute for user: {}", userId, e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Find users with specific attribute (admin endpoint) + */ + @GetMapping("/search") + public ResponseEntity> findUsersWithAttribute( + @RequestParam String attributeName, + @RequestParam String attributeValue) { + + log.debug("Finding users with attribute {}={}", attributeName, attributeValue); + + try { + List userIds = userAttributeService.findUsersWithAttribute(attributeName, attributeValue); + return ResponseEntity.ok(userIds); + + } catch (Exception e) { + log.error("Error finding users with attribute", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get all unique attribute names (admin endpoint) + */ + @GetMapping("/names") + public ResponseEntity> getAllAttributeNames() { + log.debug("Getting all unique attribute names"); + + try { + List attributeNames = userAttributeService.getAllAttributeNames(); + return ResponseEntity.ok(attributeNames); + + } catch (Exception e) { + log.error("Error getting attribute names", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Get attribute statistics (admin endpoint) + */ + @GetMapping("/statistics") + public ResponseEntity> getAttributeStatistics() { + log.debug("Getting attribute statistics"); + + try { + Map stats = userAttributeService.getAttributeStatistics(); + return ResponseEntity.ok(stats); + + } catch (Exception e) { + log.error("Error getting attribute statistics", e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Convert UserAttribute entity to DTO + */ + private UserAttributeDTO convertToDTO(UserAttribute attribute) { + return UserAttributeDTO.builder() + .id(attribute.getId()) + .userId(attribute.getUserId()) + .attributeName(attribute.getAttributeName()) + .attributeValue(attribute.getAttributeValue()) + .attributeType(attribute.getAttributeType()) + .source(attribute.getSource()) + .isActive(attribute.getIsActive()) + .createdAt(attribute.getCreatedAt()) + .updatedAt(attribute.getUpdatedAt()) + .syncedFromKeycloak(attribute.getSyncedFromKeycloak()) + .build(); + } + +} \ No newline at end of file diff --git a/api/src/main/java/io/sentrius/sso/controllers/view/UserController.java b/api/src/main/java/io/sentrius/sso/controllers/view/UserController.java index 3eb30a22..9fc2e9c1 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/view/UserController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/view/UserController.java @@ -28,8 +28,10 @@ import io.sentrius.sso.core.services.UserPublicKeyService; import io.sentrius.sso.core.services.UserService; import io.sentrius.sso.core.services.WorkHoursService; +import io.sentrius.sso.core.services.users.UserAttributeService; import io.sentrius.sso.core.services.security.CryptoService; import io.sentrius.sso.core.utils.JsonUtil; +import io.sentrius.sso.core.model.users.UserAttribute; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import lombok.extern.slf4j.Slf4j; @@ -50,12 +52,14 @@ public class UserController extends BaseController { final HostGroupService hostGroupService; final WorkHoursService workHoursService; final CryptoService cryptoService; + final UserAttributeService userAttributeService; protected UserController( UserService userService, SystemOptions systemOptions, ErrorOutputService errorOutputService, UserCustomizationService userThemeService, UserPublicKeyService userPublicKeyService, HostGroupService hostGroupService, - WorkHoursService workHoursService, CryptoService cryptoService + WorkHoursService workHoursService, CryptoService cryptoService, + UserAttributeService userAttributeService ) { super(userService, systemOptions, errorOutputService); this.userThemeService = userThemeService; @@ -63,6 +67,7 @@ protected UserController( this.hostGroupService = hostGroupService; this.workHoursService = workHoursService; this.cryptoService = cryptoService; + this.userAttributeService = userAttributeService; } @ModelAttribute("userSettings") @@ -173,12 +178,23 @@ public String listUsers(Model model) { public String editUser(Model model, HttpServletRequest request, HttpServletResponse response, @RequestParam("userId") String userId) throws GeneralSecurityException { model.addAttribute("globalAccessSet", UserType.createSuperUser().getAccessSet()); - Long id = Long.parseLong(cryptoService.decrypt(userId)); + var decryptedUserId = cryptoService.decrypt(userId); + Long id = Long.parseLong(decryptedUserId); User user = userService.getUserById(id); UserDTO userDTO = user.toDto(); + userDTO.userId = userId; + List userAttributes = userAttributeService.getUserAttributes(userId); + for(UserAttribute attr : userAttributes) { + if(attr.getAttributeName().equals("VISIBILITY_EXPRESSION")) { + model.addAttribute("visibilityExpression", attr.getStringValue()); + } + } var types = userService.getUserTypeList(); model.addAttribute("userTypes",types); model.addAttribute("user", userDTO); + log.info("Editing user: {}", userDTO); + log.info("user id is {}", userId); + model.addAttribute("userId", userId); return "sso/users/edit_user"; } @@ -222,6 +238,48 @@ public String auditUsers() { return "sso/users/audit_users"; } + @GetMapping("/attributes") + public String userAttributes(Model model, HttpServletRequest request, HttpServletResponse response) { + try { + var user = userService.getOperatingUser(request, response, null); + List userAttributes = userAttributeService.getUserAttributes(user.getUserId()); + model.addAttribute("userAttributes", userAttributes); + model.addAttribute("userId", user.getUserId()); + model.addAttribute("availableTypes", UserAttribute.AttributeType.values()); + model.addAttribute("availableSources", UserAttribute.Source.values()); + return "sso/users/user_attributes"; + } catch (Exception e) { + log.error("Error loading user attributes", e); + model.addAttribute("error", "Error loading user attributes: " + e.getMessage()); + return "sso/users/user_attributes"; + } + } + + @GetMapping("/attributes/manage") + public String manageUserAttributes(Model model, + @RequestParam(required = false) String userId, + HttpServletRequest request, + HttpServletResponse response) { + try { + String targetUserId = userId; + if (targetUserId == null) { + var user = userService.getOperatingUser(request, response, null); + targetUserId = user.getUserId(); + } + + List userAttributes = userAttributeService.getUserAttributes(targetUserId); + model.addAttribute("userAttributes", userAttributes); + model.addAttribute("targetUserId", targetUserId); + model.addAttribute("availableTypes", UserAttribute.AttributeType.values()); + model.addAttribute("availableSources", UserAttribute.Source.values()); + model.addAttribute("newAttribute", new UserAttribute()); + return "sso/users/manage_user_attributes"; + } catch (Exception e) { + log.error("Error loading user attributes for management", e); + model.addAttribute("error", "Error loading user attributes: " + e.getMessage()); + return "sso/users/manage_user_attributes"; + } + } } diff --git a/api/src/main/resources/db/migration/V21__agent_memory_store.sql b/api/src/main/resources/db/migration/V21__agent_memory_store.sql new file mode 100644 index 00000000..0716f270 --- /dev/null +++ b/api/src/main/resources/db/migration/V21__agent_memory_store.sql @@ -0,0 +1,105 @@ +-- Create agent memory store with markings support +CREATE TABLE agent_memory ( + id BIGSERIAL PRIMARY KEY, + memory_key VARCHAR(255) NOT NULL, + memory_value TEXT NOT NULL, + memory_type VARCHAR(50) DEFAULT 'JSON', + agent_id VARCHAR(255), + agent_name VARCHAR(255), + conversation_id VARCHAR(255), + classification VARCHAR(50) DEFAULT 'PRIVATE', + markings VARCHAR(255), + access_level VARCHAR(50) DEFAULT 'AGENT_ONLY', + creator_user_id VARCHAR(255), + creator_user_type VARCHAR(100), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP, + shared_with_agents TEXT, + metadata JSONB, + version INTEGER DEFAULT 1 +); + +-- Create index for efficient memory retrieval +CREATE INDEX idx_agent_memory_agent_id ON agent_memory(agent_id); +CREATE INDEX idx_agent_memory_conversation_id ON agent_memory(conversation_id); +CREATE INDEX idx_agent_memory_classification ON agent_memory(classification); +CREATE INDEX idx_agent_memory_access_level ON agent_memory(access_level); +CREATE INDEX idx_agent_memory_markings ON agent_memory(markings); +CREATE INDEX idx_agent_memory_creator ON agent_memory(creator_user_id); + +-- Create enhanced user attributes table for ABAC +CREATE TABLE user_attributes ( + id BIGSERIAL PRIMARY KEY, + user_id VARCHAR(255) NOT NULL, + attribute_name VARCHAR(255) NOT NULL, + attribute_value TEXT NOT NULL, + attribute_type VARCHAR(50) DEFAULT 'STRING', + source VARCHAR(50) DEFAULT 'SENTRIUS', + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + synced_from_keycloak BOOLEAN DEFAULT false, + UNIQUE(user_id, attribute_name) +); + +-- Create index for user attributes +CREATE INDEX idx_user_attributes_user_id ON user_attributes(user_id); +CREATE INDEX idx_user_attributes_name ON user_attributes(attribute_name); +CREATE INDEX idx_user_attributes_active ON user_attributes(is_active); + +-- Create memory access policies table for ABAC +CREATE TABLE memory_access_policies ( + id BIGSERIAL PRIMARY KEY, + policy_name VARCHAR(255) NOT NULL UNIQUE, + policy_description TEXT, + target_classification VARCHAR(50), + target_markings VARCHAR(255), + required_user_attributes JSONB, + required_agent_attributes JSONB, + access_type VARCHAR(50) DEFAULT 'READ', + is_active BOOLEAN DEFAULT true, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create index for memory access policies +CREATE INDEX idx_memory_policies_classification ON memory_access_policies(target_classification); +CREATE INDEX idx_memory_policies_markings ON memory_access_policies(target_markings); +CREATE INDEX idx_memory_policies_active ON memory_access_policies(is_active); + +-- Insert default memory access policies +INSERT INTO memory_access_policies ( + policy_name, + policy_description, + target_classification, + access_type, + required_user_attributes +) VALUES +('PUBLIC_READ', 'Allow read access to public memory for all users', 'PUBLIC', 'READ', '{}'), +('PRIVATE_OWNER_ONLY', 'Allow full access to private memory only for creator', 'PRIVATE', 'FULL', '{"created_by": "user_id"}'), +('SHARED_TEAM_READ', 'Allow read access to shared memory for team members', 'SHARED', 'READ', '{"team": "required"}'), +('CONFIDENTIAL_ADMIN_ONLY', 'Allow access to confidential memory only for admins', 'CONFIDENTIAL', 'FULL', '{"user_type": "ADMIN"}'); + +-- Insert default classification examples +INSERT INTO agent_memory ( + memory_key, + memory_value, + memory_type, + agent_id, + classification, + markings, + access_level, + creator_user_id, + metadata +) VALUES +('system.welcome_message', + '{"message": "Welcome to Sentrius Agent Memory Store", "version": "1.0"}', + 'JSON', + 'system', + 'PUBLIC', + 'SYSTEM,WELCOME', + 'ALL_USERS', + 'system', + '{"is_system": true, "category": "documentation"}' +); \ No newline at end of file diff --git a/api/src/main/resources/db/migration/V22__add_vector_support.sql b/api/src/main/resources/db/migration/V22__add_vector_support.sql new file mode 100644 index 00000000..b8418f39 --- /dev/null +++ b/api/src/main/resources/db/migration/V22__add_vector_support.sql @@ -0,0 +1,36 @@ +-- Add vector support to agent memory store +-- Enable pgvector extension for PostgreSQL vector operations +CREATE EXTENSION IF NOT EXISTS vector; + +-- Add embedding column to agent_memory table for semantic search +ALTER TABLE agent_memory ADD COLUMN embedding vector(1536); + +-- Create index for vector similarity search using cosine distance +CREATE INDEX idx_agent_memory_embedding ON agent_memory USING ivfflat (embedding vector_cosine_ops); + +-- Create additional indexes for hybrid search (combining vector with markings) +CREATE INDEX idx_agent_memory_embedding_classification ON agent_memory (classification, embedding); +CREATE INDEX idx_agent_memory_embedding_markings ON agent_memory (markings, embedding); + +-- Add metadata for vector store configuration +INSERT INTO agent_memory ( + memory_key, + memory_value, + memory_type, + agent_id, + classification, + markings, + access_level, + creator_user_id, + metadata +) VALUES +('system.vector_store_config', + '{"dimension": 1536, "similarity_function": "cosine", "index_type": "ivfflat", "enabled": true}', + 'JSON', + 'system', + 'PUBLIC', + 'SYSTEM,VECTOR_STORE,CONFIG', + 'ALL_USERS', + 'system', + '{"is_system": true, "category": "configuration", "version": "1.0"}' +); \ No newline at end of file diff --git a/api/src/main/resources/db/migration/V23__agent_memory_indexes.sql b/api/src/main/resources/db/migration/V23__agent_memory_indexes.sql new file mode 100644 index 00000000..28fa23ee --- /dev/null +++ b/api/src/main/resources/db/migration/V23__agent_memory_indexes.sql @@ -0,0 +1,28 @@ +-- Drop problematic old index (if it exists) +DROP INDEX IF EXISTS idx_agent_memory_embedding_classification; + +-- Vector similarity search index +CREATE INDEX IF NOT EXISTS idx_agent_memory_embedding + ON agent_memory + USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 100); + +-- Scalar filters +CREATE INDEX IF NOT EXISTS idx_agent_memory_classification + ON agent_memory (classification); + +CREATE INDEX IF NOT EXISTS idx_agent_memory_markings + ON agent_memory (markings); + +CREATE INDEX IF NOT EXISTS idx_agent_memory_expires_at + ON agent_memory (expires_at); + +-- JSONB metadata indexing +CREATE INDEX IF NOT EXISTS idx_agent_memory_metadata_gin + ON agent_memory + USING gin (metadata jsonb_path_ops); + +-- Optional full-text index for memory_value +CREATE INDEX IF NOT EXISTS idx_agent_memory_memory_value_fts + ON agent_memory + USING gin (to_tsvector('english', memory_value)); diff --git a/api/src/main/resources/db/migration/V24__agent_memory_indexes.sql b/api/src/main/resources/db/migration/V24__agent_memory_indexes.sql new file mode 100644 index 00000000..4b8a8f21 --- /dev/null +++ b/api/src/main/resources/db/migration/V24__agent_memory_indexes.sql @@ -0,0 +1,17 @@ +-- V__drop_bad_embedding_indexes.sql + +-- Drop problematic composite B-Tree indexes +DROP INDEX IF EXISTS idx_agent_memory_embedding_classification; +DROP INDEX IF EXISTS idx_agent_memory_embedding_markings; + +-- Create vector similarity index +CREATE INDEX IF NOT EXISTS idx_agent_memory_embedding_ivfflat + ON agent_memory USING ivfflat (embedding vector_l2_ops) + WITH (lists = 100); + +-- Index classification and markings separately for filtering +CREATE INDEX IF NOT EXISTS idx_agent_memory_classification + ON agent_memory (classification); + +CREATE INDEX IF NOT EXISTS idx_agent_memory_markings + ON agent_memory (markings); diff --git a/api/src/main/resources/templates/sso/atpl/configure.html b/api/src/main/resources/templates/sso/atpl/configure.html index 375fb4d4..8b483200 100644 --- a/api/src/main/resources/templates/sso/atpl/configure.html +++ b/api/src/main/resources/templates/sso/atpl/configure.html @@ -1055,11 +1055,46 @@
Endpoint Rule ${endpointCounter + 1}
} } + function incrementVersion(current) { + if (!current) return "v1"; + + // Match things like "v0", "v12", "v1.2.3" + const match = current.match(/^v(\d+)(?:\.(\d+)(?:\.(\d+))?)?$/); + if (!match) { + // If it doesn’t match our format, just return current unchanged + return current; + } + + let major = parseInt(match[1] || "0", 10); + let minor = parseInt(match[2] || "0", 10); + let patch = parseInt(match[3] || "0", 10); + + if (match[3] !== undefined) { + // If we have patch version: bump patch + patch++; + return `v${major}.${minor}.${patch}`; + } else if (match[2] !== undefined) { + // If we have minor version: bump minor + minor++; + return `v${major}.${minor}`; + } else { + // Just major: bump major + major++; + return `v${major}`; + } + } + // Save policy function savePolicy() { try { const policy = formDataToPolicyObject(); - + + // Auto-increment version if it matches known pattern + policy.version = incrementVersion(policy.version); + + // Update the input field so user sees the new version + document.getElementById('version').value = policy.version; + // Basic validation if (!policy.version || !policy.policy_id) { alert('Version and Policy ID are required fields.'); @@ -1220,6 +1255,18 @@
Endpoint Rule ${endpointCounter + 1}
if (primitive.id) container.querySelector('input[name*=".id"]').value = primitive.id; if (primitive.description) container.querySelector('input[name*=".description"]').value = primitive.description; + if (primitive.endpoints) { + const endpointsList = container.querySelector('#primitive-endpoints-' + (primitiveCounter - 1)); + primitive.endpoints.forEach(endpoint => { + addListItem( + endpointsList.id, + 'capabilities.primitives[' + (primitiveCounter - 1) + '].endpoints' + ); + const inputs = endpointsList.querySelectorAll('input'); + inputs[inputs.length - 1].value = endpoint; + }); + } + /* if (primitive.endpoints) { const tagsList = container.querySelector('.dynamic-list'); primitive.endpoints.forEach(endpoint => { @@ -1231,6 +1278,7 @@
Endpoint Rule ${endpointCounter + 1}
inputs[inputs.length - 1].value = endpoint; }); } + */ if (primitive.tags) { diff --git a/api/src/main/resources/templates/sso/users/edit_user.html b/api/src/main/resources/templates/sso/users/edit_user.html index 55b41cf7..300de23a 100644 --- a/api/src/main/resources/templates/sso/users/edit_user.html +++ b/api/src/main/resources/templates/sso/users/edit_user.html @@ -67,6 +67,22 @@

Edit User

Cancel + + +
+

Access Policy Expression (ABAC)

+
+
+ + +
+ + Use & for AND, | for OR, () for grouping. + Example: (DEV&LEAD)|DATA_ATTRIBUTE + + @@ -77,20 +93,18 @@

Edit User

const form = document.getElementById("editUserForm"); const alertContainer = document.getElementById("alertContainer"); + // Normal user form submit form.addEventListener("submit", function (e) { e.preventDefault(); const formData = new FormData(form); const payload = {}; let csrfToken = null; - let csrfParam = null; formData.forEach((value, key) => { if (key.toLowerCase().includes("csrf")) { - csrfParam = key; csrfToken = value; } else { - // Support nested keys (e.g., authorizationType.id → { authorizationType: { id: ... } }) if (key.includes(".")) { const [parent, child] = key.split("."); if (!payload[parent]) payload[parent] = {}; @@ -126,6 +140,47 @@

Edit User

`; }); }); + + // Save visibility expression + const userId = document.querySelector("input[name='userId']").value; + const saveBtn = document.getElementById("saveExpressionBtn"); + const input = document.getElementById("visibilityExpression"); + const attributeAlert = document.getElementById("attributeAlert"); + + saveBtn.addEventListener("click", function () { + const csrfToken = document.getElementById("csrfToken").value; + + const payload = { + attributeName: "VISIBILITY_EXPRESSION", + attributeValue: input.value, + attributeType: "STRING", + source: "SENTRIUS", + isActive: true + }; + + fetch(`/api/v1/users/attributes/update?userId=${userId}`, { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-Requested-With": "XMLHttpRequest", + "X-CSRF-TOKEN": csrfToken + }, + body: JSON.stringify(payload) + }) + .then(res => { + if (!res.ok) throw new Error("Failed to save expression"); + return res.json(); + }) + .then(() => { + attributeAlert.innerHTML = + `
Expression saved successfully
`; + }) + .catch(err => { + console.error(err); + attributeAlert.innerHTML = + `
Error saving expression
`; + }); + }); }); diff --git a/api/src/test/java/io/sentrius/sso/controllers/api/UserPublicKeyApiControllerTest.java b/api/src/test/java/io/sentrius/sso/controllers/api/users/UserPublicKeyApiControllerTest.java similarity index 98% rename from api/src/test/java/io/sentrius/sso/controllers/api/UserPublicKeyApiControllerTest.java rename to api/src/test/java/io/sentrius/sso/controllers/api/users/UserPublicKeyApiControllerTest.java index 759986b8..1bea4b72 100644 --- a/api/src/test/java/io/sentrius/sso/controllers/api/UserPublicKeyApiControllerTest.java +++ b/api/src/test/java/io/sentrius/sso/controllers/api/users/UserPublicKeyApiControllerTest.java @@ -1,6 +1,5 @@ -package io.sentrius.sso.controllers.api; +package io.sentrius.sso.controllers.api.users; -import com.fasterxml.jackson.databind.ObjectMapper; import io.sentrius.sso.core.config.SystemOptions; import io.sentrius.sso.core.dto.UserPublicKeyDTO; import io.sentrius.sso.core.model.users.User; diff --git a/core/pom.xml b/core/pom.xml index 85966178..82ca3999 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -200,6 +200,14 @@ classgraph + + + org.springframework.ai + spring-ai-pgvector-store + ${spring-ai-version} + true + + diff --git a/core/src/main/java/io/sentrius/sso/core/config/VectorStoreConfig.java b/core/src/main/java/io/sentrius/sso/core/config/VectorStoreConfig.java new file mode 100644 index 00000000..4bff52c0 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/config/VectorStoreConfig.java @@ -0,0 +1,83 @@ +package io.sentrius.sso.core.config; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestTemplate; + +/** + * Configuration for vector store capabilities in agent memory store. + * Configuration values are read from application properties but can be + * overridden dynamically through SystemOptions where available. + */ +@Slf4j +@Configuration +public class VectorStoreConfig { + + @Value("${sentrius.memory.vector.dimension:1536}") + private int vectorDimension; + + @Value("${sentrius.memory.vector.similarity-threshold:0.7}") + private double defaultSimilarityThreshold; + + @Value("${sentrius.memory.vector.enabled:true}") + private boolean vectorStoreEnabled; + + /** + * Configuration properties for vector store + */ + @Bean + public VectorStoreProperties vectorStoreProperties() { + VectorStoreProperties properties = new VectorStoreProperties(); + properties.setDimension(vectorDimension); + properties.setDefaultSimilarityThreshold(defaultSimilarityThreshold); + properties.setEnabled(vectorStoreEnabled); + + log.info("Vector store configuration: enabled={}, dimension={}, threshold={}", + properties.isEnabled(), properties.getDimension(), properties.getDefaultSimilarityThreshold()); + + return properties; + } + + /** + * RestTemplate for HTTP calls + */ + @Bean + public RestTemplate restTemplate() { + return new RestTemplate(); + } + + /** + * Properties class for vector store configuration + */ + public static class VectorStoreProperties { + private int dimension; + private double defaultSimilarityThreshold; + private boolean enabled; + + public int getDimension() { + return dimension; + } + + public void setDimension(int dimension) { + this.dimension = dimension; + } + + public double getDefaultSimilarityThreshold() { + return defaultSimilarityThreshold; + } + + public void setDefaultSimilarityThreshold(double defaultSimilarityThreshold) { + this.defaultSimilarityThreshold = defaultSimilarityThreshold; + } + + public boolean isEnabled() { + return enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + } +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/data/AgentMessageBase.java b/core/src/main/java/io/sentrius/sso/core/data/AgentMessageBase.java new file mode 100644 index 00000000..ed614d71 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/data/AgentMessageBase.java @@ -0,0 +1,19 @@ +package io.sentrius.sso.core.data; + +import java.util.ArrayList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +@SuperBuilder +@Getter +@NoArgsConstructor +@AllArgsConstructor +public class AgentMessageBase { + + @Builder.Default + List messages = new ArrayList<>(); +} diff --git a/core/src/main/java/io/sentrius/sso/core/data/VectorMemoryStore.java b/core/src/main/java/io/sentrius/sso/core/data/VectorMemoryStore.java new file mode 100644 index 00000000..efe41dff --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/data/VectorMemoryStore.java @@ -0,0 +1,9 @@ +package io.sentrius.sso.core.data; + +import java.util.List; +import java.util.Map; + +public interface VectorMemoryStore { + void upsert(String collection, String id, float[] vector, Map payload); + List search(String collection, float[] queryVector, int topK, Map filter); +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/data/VectorResult.java b/core/src/main/java/io/sentrius/sso/core/data/VectorResult.java new file mode 100644 index 00000000..c7e3abb0 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/data/VectorResult.java @@ -0,0 +1,5 @@ +package io.sentrius.sso.core.data; + +import com.fasterxml.jackson.databind.JsonNode; + +public record VectorResult(String id, float score, JsonNode payload) {} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/dto/AgentRegistrationDTO.java b/core/src/main/java/io/sentrius/sso/core/dto/AgentRegistrationDTO.java index aa646d3e..a1022114 100644 --- a/core/src/main/java/io/sentrius/sso/core/dto/AgentRegistrationDTO.java +++ b/core/src/main/java/io/sentrius/sso/core/dto/AgentRegistrationDTO.java @@ -18,7 +18,8 @@ public class AgentRegistrationDTO { private final String agentPublicKeyAlgo; private final String clientSecret; private final String clientId; - private final String agentType; + @Builder.Default + private final String agentType = "chat"; private final String agentCallbackUrl; @Builder.Default private final String agentContextId = ""; diff --git a/core/src/main/java/io/sentrius/sso/core/dto/PagedResultDTO.java b/core/src/main/java/io/sentrius/sso/core/dto/PagedResultDTO.java new file mode 100644 index 00000000..0df48b5d --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/dto/PagedResultDTO.java @@ -0,0 +1,27 @@ +package io.sentrius.sso.core.dto; + +import java.util.List; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Builder +@Getter +@Setter +@AllArgsConstructor +@NoArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class PagedResultDTO { + private List content; + private int totalPages; + private long totalElements; + private int number; + private int size; + private boolean last; + private boolean first; + + // getters and setters +} diff --git a/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentExecutionContextDTO.java b/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentExecutionContextDTO.java index b1fc0259..e93c75d4 100644 --- a/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentExecutionContextDTO.java +++ b/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentExecutionContextDTO.java @@ -2,6 +2,7 @@ import java.io.IOException; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; @@ -38,6 +39,10 @@ public class AgentExecutionContextDTO { @Builder.Default private ObjectNode callParams = JsonUtil.MAPPER.createObjectNode(); + @Builder.Default + + private final Map longTermMemories = new ConcurrentHashMap<>(); + // === Memory Management === public void addToMemory(JsonNode node) { @@ -55,6 +60,72 @@ public void putStructuredToMemory(String key, JsonNode value) { ObjectNode wrapper = JsonUtil.MAPPER.createObjectNode(); wrapper.set(key, value); agentDataList.add(wrapper); + + } + + // === Persistent Memory Integration === + + /** + * Store value to persistent memory with markings + */ + public void addToPersistentMemory(String key, JsonNode value, String classification, String[] markings) { + // Add to short-term memory for immediate access + putStructuredToMemory(key, value); + + // Store metadata for persistent storage + if (agentContext != null) { + ObjectNode memoryMeta = JsonUtil.MAPPER.createObjectNode(); + memoryMeta.put("key", key); + memoryMeta.set("value", value); + memoryMeta.put("classification", classification != null ? classification : "PRIVATE"); + if (markings != null) { + memoryMeta.put("markings", String.join(",", markings)); + } + memoryMeta.put("persistent", true); + + // Add to data list with persistent flag + longTermMemories.put(key, memoryMeta); + + } + } + + /** + * Store simple value to persistent memory + */ + public void addToPersistentMemory(String key, String value, String classification, String[] markings) { + JsonNode jsonValue = JsonUtil.MAPPER.convertValue(value, JsonNode.class); + addToPersistentMemory(key, jsonValue, classification, markings); + } + + /** + * Store object to persistent memory + */ + public void addToPersistentMemory(String key, Object value, String classification, String[] markings) { + JsonNode jsonValue = JsonUtil.MAPPER.convertValue(value, JsonNode.class); + addToPersistentMemory(key, jsonValue, classification, markings); + } + + /** + * Get list of memories marked for persistent storage + */ + public Map getPersistentMemoryItems() { + + return longTermMemories; + } + + /** + * Check if context has persistent memory items + */ + public boolean hasPersistentMemory() { + return !getPersistentMemoryItems().isEmpty(); + } + + public Map flushPersistentMemory() { + var persistentItems = getPersistentMemoryItems(); + longTermMemories.clear(); + return persistentItems; + + } public static void flatten(String prefix, JsonNode node, Map map) { diff --git a/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentMemoryDTO.java b/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentMemoryDTO.java new file mode 100644 index 00000000..f9aab291 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/dto/agents/AgentMemoryDTO.java @@ -0,0 +1,87 @@ +package io.sentrius.sso.core.dto.agents; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +import java.time.Instant; +import java.util.Map; +import java.util.Arrays; + +@Builder +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class AgentMemoryDTO { + + private Long id; + private String memoryKey; + private String memoryValue; + private String memoryType; + private String agentId; + private String agentName; + private String conversationId; + private String classification; + private String[] markings; + private String accessLevel; + private String creatorUserId; + private String creatorUserType; + private Instant createdAt; + private Instant updatedAt; + private Instant expiresAt; + private String[] sharedWithAgents; + private Map metadata; + private Integer version; + + // Vector embedding for semantic search (optional, for display purposes) + private float[] embedding; + private boolean hasEmbedding; + + // Helper methods for markings + public void setMarkingsFromString(String markingsStr) { + this.markings = markingsStr != null ? markingsStr.split(",") : new String[0]; + } + + public String getMarkingsAsString() { + return markings != null ? String.join(",", markings) : null; + } + + // Helper methods for shared agents + public void setSharedAgentsFromString(String sharedStr) { + this.sharedWithAgents = sharedStr != null ? sharedStr.split(",") : new String[0]; + } + + public String getSharedAgentsAsString() { + return sharedWithAgents != null ? String.join(",", sharedWithAgents) : null; + } + + // Validation helpers + public boolean isExpired() { + return expiresAt != null && Instant.now().isAfter(expiresAt); + } + + public boolean hasMarking(String marking) { + if (markings == null) return false; + for (String m : markings) { + if (m.trim().equalsIgnoreCase(marking.trim())) { + return true; + } + } + return false; + } + + // Helper methods for embeddings + public String getEmbeddingAsString() { + return embedding != null ? Arrays.toString(embedding) : null; + } + + public void setEmbeddingFromArray(float[] embedding) { + this.embedding = embedding; + this.hasEmbedding = embedding != null && embedding.length > 0; + } +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/dto/agents/MemoryQueryDTO.java b/core/src/main/java/io/sentrius/sso/core/dto/agents/MemoryQueryDTO.java new file mode 100644 index 00000000..08fa8256 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/dto/agents/MemoryQueryDTO.java @@ -0,0 +1,92 @@ +package io.sentrius.sso.core.dto.agents; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import lombok.ToString; +import org.springframework.data.domain.Sort; + +import java.util.List; + +@Builder +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +@ToString +@JsonIgnoreProperties(ignoreUnknown = true) +public class MemoryQueryDTO { + + private String agentId; + private String classification; + private String markings; + private String accessLevel; + private String creatorUserId; + private String memoryKey; + private String searchTerm; + private Boolean includeExpired; + private List includeMarkings; + private List excludeMarkings; + private String memoryType; + + // Pagination parameters + @Builder.Default + private Integer page = 0; + @Builder.Default + private Integer size = 100; + private String sortBy; + private Sort.Direction sortDirection; + + // Response filtering + private Boolean includeMetadata; + private Boolean includeSharedAgents; + private Boolean excludeValues; // Only return metadata, not actual values + + // Default values + public Integer getPage() { + return page != null ? page : 0; + } + + public Integer getSize() { + return size != null ? size : 20; + } + + public String getSortBy() { + return sortBy != null ? sortBy : "createdAt"; + } + + public Sort.Direction getSortDirection() { + return sortDirection != null ? sortDirection : Sort.Direction.DESC; + } + + public Boolean getIncludeExpired() { + return includeExpired != null ? includeExpired : false; + } + + public Boolean getIncludeMetadata() { + return includeMetadata != null ? includeMetadata : true; + } + + public Boolean getIncludeSharedAgents() { + return includeSharedAgents != null ? includeSharedAgents : true; + } + + public Boolean getExcludeValues() { + return excludeValues != null ? excludeValues : false; + } + + // Helper methods for validation + public boolean isValid() { + return getPage() >= 0 && getSize() > 0 && getSize() <= 100; + } + + public boolean hasFilters() { + return agentId != null || classification != null || markings != null || + accessLevel != null || creatorUserId != null || memoryKey != null || + searchTerm != null || (includeMarkings != null && !includeMarkings.isEmpty()) || + (excludeMarkings != null && !excludeMarkings.isEmpty()); + } +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/dto/capabilities/EndpointDescriptor.java b/core/src/main/java/io/sentrius/sso/core/dto/capabilities/EndpointDescriptor.java index abb56a22..345c670c 100644 --- a/core/src/main/java/io/sentrius/sso/core/dto/capabilities/EndpointDescriptor.java +++ b/core/src/main/java/io/sentrius/sso/core/dto/capabilities/EndpointDescriptor.java @@ -2,9 +2,12 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.databind.node.ObjectNode; +import io.sentrius.sso.core.utils.JsonUtil; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; @@ -19,6 +22,7 @@ @Getter @Setter @ToString +@EqualsAndHashCode @NoArgsConstructor @AllArgsConstructor public class EndpointDescriptor { @@ -42,4 +46,19 @@ public class EndpointDescriptor { private boolean requiresTokenManagement = false; private Class returnType; + + public static String toEmbeddableJson(EndpointDescriptor ed) { + ObjectNode node = JsonUtil.MAPPER.createObjectNode(); + node.put("name", ed.getName()); + node.put("description", ed.getDescription()); + node.put("type", ed.getType()); + node.put("httpMethod", ed.getHttpMethod()); + node.put("path", ed.getPath()); + node.put("className", ed.getClassName()); + node.put("methodName", ed.getMethodName()); + node.put("requiresAuthentication", ed.isRequiresAuthentication()); + node.put("requiresTokenManagement", ed.isRequiresTokenManagement()); + + return node.toString(); + } } \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/dto/users/UserAttributeDTO.java b/core/src/main/java/io/sentrius/sso/core/dto/users/UserAttributeDTO.java new file mode 100644 index 00000000..26b1445b --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/dto/users/UserAttributeDTO.java @@ -0,0 +1,72 @@ +package io.sentrius.sso.core.dto.users; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +import java.time.Instant; + +@Builder +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class UserAttributeDTO { + + private Long id; + private String userId; + private String attributeName; + private String attributeValue; + private String attributeType; + private String source; + private Boolean isActive; + private Instant createdAt; + private Instant updatedAt; + private Boolean syncedFromKeycloak; + + // Helper methods for type-safe value access + public String getStringValue() { + return attributeValue; + } + + public Integer getIntegerValue() { + try { + return Integer.parseInt(attributeValue); + } catch (NumberFormatException e) { + return null; + } + } + + public Boolean getBooleanValue() { + return Boolean.parseBoolean(attributeValue); + } + + public String[] getListValue() { + return attributeValue != null ? attributeValue.split(",") : new String[0]; + } + + // Validation helper + public boolean isValidForType() { + if (attributeType == null) return true; + + switch (attributeType.toUpperCase()) { + case "INTEGER": + try { + Integer.parseInt(attributeValue); + return true; + } catch (NumberFormatException e) { + return false; + } + case "BOOLEAN": + return "true".equalsIgnoreCase(attributeValue) || "false".equalsIgnoreCase(attributeValue); + case "JSON": + return attributeValue.trim().startsWith("{") || attributeValue.trim().startsWith("["); + default: + return true; + } + } +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/embeddings/EmbeddingService.java b/core/src/main/java/io/sentrius/sso/core/embeddings/EmbeddingService.java new file mode 100644 index 00000000..b6247971 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/embeddings/EmbeddingService.java @@ -0,0 +1,9 @@ +package io.sentrius.sso.core.embeddings; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.sentrius.sso.core.dto.ztat.TokenDTO; +import io.sentrius.sso.core.exceptions.ZtatException; + +public interface EmbeddingService { + float[] embed(TokenDTO dto, String text) throws ZtatException, JsonProcessingException; +} diff --git a/core/src/main/java/io/sentrius/sso/core/model/verbs/Endpoint.java b/core/src/main/java/io/sentrius/sso/core/model/verbs/Endpoint.java new file mode 100644 index 00000000..9b675dee --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/model/verbs/Endpoint.java @@ -0,0 +1,13 @@ +package io.sentrius.sso.core.model.verbs; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface Endpoint { + String description() default ""; + +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/services/agents/AgentClientService.java b/core/src/main/java/io/sentrius/sso/core/services/agents/AgentClientService.java index 5f17cb81..53a08e13 100644 --- a/core/src/main/java/io/sentrius/sso/core/services/agents/AgentClientService.java +++ b/core/src/main/java/io/sentrius/sso/core/services/agents/AgentClientService.java @@ -12,8 +12,11 @@ import io.sentrius.sso.core.dto.AgentCommunicationDTO; import io.sentrius.sso.core.dto.AgentHeartbeatDTO; import io.sentrius.sso.core.dto.AgentRegistrationDTO; +import io.sentrius.sso.core.dto.PagedResultDTO; import io.sentrius.sso.core.dto.agents.AgentContextDTO; import io.sentrius.sso.core.dto.agents.AgentContextRequestDTO; +import io.sentrius.sso.core.dto.agents.AgentMemoryDTO; +import io.sentrius.sso.core.dto.agents.MemoryQueryDTO; import io.sentrius.sso.core.dto.capabilities.EndpointDescriptor; import io.sentrius.sso.core.dto.agents.AgentExecution; import io.sentrius.sso.core.dto.ztat.AtatRequest; @@ -312,4 +315,37 @@ public String getCreatedAgentStatus(AgentExecution execution, String agentId) return zeroTrustClientService.callGetOnApi(execution, ask , Maps.immutableEntry("agentId", List.of(agentId))); } + + public void storeMemory(AgentExecution execution, String agentId, AgentMemoryDTO memory) { + String url = "/api/v1/agents/memory/store"; + try { + zeroTrustClientService.callPostOnApi(execution, url,memory, + Maps.immutableEntry("agentId",List.of(agentId)) ); + } catch (ZtatException e) { + log.error("Failed to store memory for key {}", e.getMessage()); + } + } + + public List retrieveMemories(AgentExecution execution, String agentId, MemoryQueryDTO query) { + String url = "/api/v1/agents/memory/query"; + try { + + PagedResultDTO page = zeroTrustClientService.callPostOnApi( + execution, + url, + query, + AgentMemoryDTO.class, + query.getPage(), query.getSize(), List.of("created_at,desc"), Maps.immutableEntry("agentId", + List.of(agentId)) + ); + + + + return page.getContent(); + + } catch (ZtatException e) { + log.error("Failed to store memory for key {}", e.getMessage()); + } + return List.of(); + } } diff --git a/core/src/main/java/io/sentrius/sso/core/services/agents/AgentExecutionService.java b/core/src/main/java/io/sentrius/sso/core/services/agents/AgentExecutionService.java index b3d7f5b8..20a2beb7 100644 --- a/core/src/main/java/io/sentrius/sso/core/services/agents/AgentExecutionService.java +++ b/core/src/main/java/io/sentrius/sso/core/services/agents/AgentExecutionService.java @@ -2,10 +2,12 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; +import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; import com.github.benmanes.caffeine.cache.LoadingCache; import io.sentrius.sso.core.dto.UserDTO; import io.sentrius.sso.core.dto.agents.AgentExecution; +import io.sentrius.sso.core.dto.agents.AgentExecutionContextDTO; import org.springframework.stereotype.Service; @Service @@ -16,6 +18,19 @@ public class AgentExecutionService { .expireAfterWrite(1, TimeUnit.HOURS) .build(this::createExecution); + private final Cache agentExecutionContextCache = + Caffeine.newBuilder() + .expireAfterWrite(1, TimeUnit.HOURS).build(); + + public void setExecutionContextDTO(AgentExecution execution, AgentExecutionContextDTO contextDTO) { + agentExecutionContextCache.put(execution.getExecutionId(), contextDTO); + } + + + public AgentExecutionContextDTO getExecutionContextDTO(String executionId) { + return agentExecutionContextCache.getIfPresent(executionId); + } + protected AgentExecution createExecution(UserDTO user){ return AgentExecution.builder().user(user).executionId(UUID.randomUUID().toString()).build(); } diff --git a/core/src/main/java/io/sentrius/sso/core/services/agents/EmbeddingService.java b/core/src/main/java/io/sentrius/sso/core/services/agents/EmbeddingService.java new file mode 100644 index 00000000..9447c102 --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/services/agents/EmbeddingService.java @@ -0,0 +1,225 @@ +package io.sentrius.sso.core.services.agents; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.ResponseEntity; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Embedding service that delegates to the integration proxy for proper security handling. + * This service acts as a facade to the integration proxy's embedding capabilities. + */ +@Slf4j +@Service +public class EmbeddingService { + + private final RestTemplate restTemplate; + private final String integrationProxyUrl; + + public EmbeddingService( + RestTemplate restTemplate, + @Value("${sentrius.integration-proxy.url:http://localhost:8081}") String integrationProxyUrl) { + this.restTemplate = restTemplate; + this.integrationProxyUrl = integrationProxyUrl; + } + + /** + * Check if embedding service is available + */ + public boolean isAvailable() { + try { + String url = integrationProxyUrl + "/api/v1/embeddings/status"; + HttpHeaders headers = createAuthHeaders(); + if (headers == null) { + return false; + } + + HttpEntity entity = new HttpEntity<>(headers); + ResponseEntity response = restTemplate.exchange(url, HttpMethod.GET, entity, Map.class); + + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + Map status = response.getBody(); + return Boolean.TRUE.equals(status.get("available")); + } + + return false; + } catch (Exception e) { + log.warn("Failed to check embedding service availability: {}", e.getMessage()); + return false; + } + } + + /** + * Generate embedding for the given text via integration proxy + */ + public float[] embed(String text) { + if (text == null || text.trim().isEmpty()) { + log.warn("Cannot generate embedding for empty text"); + return null; + } + + try { + String url = integrationProxyUrl + "/api/v1/embeddings/generate"; + HttpHeaders headers = createAuthHeaders(); + if (headers == null) { + log.warn("No authentication context available for embedding generation"); + return null; + } + + Map requestBody = new HashMap<>(); + requestBody.put("text", text); + + HttpEntity> entity = new HttpEntity<>(requestBody, headers); + + ResponseEntity response = restTemplate.exchange(url, HttpMethod.POST, entity, Map.class); + + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + Map responseBody = response.getBody(); + + if (responseBody.containsKey("embedding")) { + Object embeddingObj = responseBody.get("embedding"); + + if (embeddingObj instanceof float[]) { + return (float[]) embeddingObj; + } else if (embeddingObj instanceof List) { + @SuppressWarnings("unchecked") + List embeddingList = (List) embeddingObj; + float[] result = new float[embeddingList.size()]; + for (int i = 0; i < embeddingList.size(); i++) { + result[i] = embeddingList.get(i).floatValue(); + } + + log.debug("Generated embedding with {} dimensions for text length: {}", + result.length, text.length()); + return result; + } + } + } + + log.warn("Failed to generate embedding - unexpected response format"); + return null; + + } catch (Exception e) { + log.error("Error generating embedding for text: {}", text.substring(0, Math.min(100, text.length())), e); + return null; + } + } + + /** + * Generate embeddings for multiple texts in batch via integration proxy + */ + public Map embedBatch(List texts) { + if (texts == null || texts.isEmpty()) { + return new HashMap<>(); + } + + try { + String url = integrationProxyUrl + "/api/v1/embeddings/generate/batch"; + HttpHeaders headers = createAuthHeaders(); + if (headers == null) { + log.warn("No authentication context available for batch embedding generation"); + return new HashMap<>(); + } + + Map requestBody = new HashMap<>(); + requestBody.put("texts", texts); + + HttpEntity> entity = new HttpEntity<>(requestBody, headers); + + ResponseEntity response = restTemplate.exchange(url, HttpMethod.POST, entity, Map.class); + + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + Map responseBody = response.getBody(); + + if (responseBody.containsKey("embeddings")) { + @SuppressWarnings("unchecked") + Map embeddings = (Map) responseBody.get("embeddings"); + + log.debug("Generated batch embeddings for {} texts", embeddings.size()); + return embeddings; + } + } + + log.warn("Failed to generate batch embeddings - unexpected response format"); + return new HashMap<>(); + + } catch (Exception e) { + log.error("Error generating batch embeddings for {} texts", texts.size(), e); + return new HashMap<>(); + } + } + + /** + * Calculate cosine similarity between two embeddings + */ + public static double calculateCosineSimilarity(float[] embedding1, float[] embedding2) { + if (embedding1 == null || embedding2 == null || embedding1.length != embedding2.length) { + return 0.0; + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < embedding1.length; i++) { + dotProduct += embedding1[i] * embedding2[i]; + normA += Math.pow(embedding1[i], 2); + normB += Math.pow(embedding2[i], 2); + } + + if (normA == 0.0 || normB == 0.0) { + return 0.0; + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + + /** + * Create authentication headers for integration proxy calls + */ + private HttpHeaders createAuthHeaders() { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication == null) { + return null; + } + + // For now, assume we have a Bearer token available + // In a real implementation, this would extract the JWT token from the security context + String token = extractTokenFromAuthentication(authentication); + if (token == null) { + return null; + } + + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + token); + headers.set("Content-Type", "application/json"); + + return headers; + } + + /** + * Extract JWT token from authentication context + * This is a placeholder implementation - in practice, you'd extract the actual JWT + */ + private String extractTokenFromAuthentication(Authentication authentication) { + // This is a simplified implementation + // In practice, you'd extract the actual JWT token from the authentication object + if (authentication.getCredentials() instanceof String) { + return (String) authentication.getCredentials(); + } + + // For now, return null to indicate no token available + // This would need to be implemented based on your specific authentication setup + return null; + } +} \ No newline at end of file diff --git a/core/src/main/java/io/sentrius/sso/core/services/agents/LLMService.java b/core/src/main/java/io/sentrius/sso/core/services/agents/LLMService.java index 1dc53045..b98d786d 100644 --- a/core/src/main/java/io/sentrius/sso/core/services/agents/LLMService.java +++ b/core/src/main/java/io/sentrius/sso/core/services/agents/LLMService.java @@ -1,14 +1,18 @@ package io.sentrius.sso.core.services.agents; +import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; import io.sentrius.sso.core.dto.ztat.TokenDTO; +import io.sentrius.sso.core.embeddings.EmbeddingService; import io.sentrius.sso.core.exceptions.ZtatException; +import io.sentrius.sso.core.utils.JsonUtil; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; import org.springframework.stereotype.Service; @Service @ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) -public class LLMService { +public class LLMService implements EmbeddingService { final ZeroTrustClientService zeroTrustClientService; @@ -26,4 +30,20 @@ public String askQuestion(TokenDTO dto, T body) throws ZtatException { return zeroTrustClientService.callPostOnApi(dto, openAiEndpoint, "/chat/completions", body); } + @Override + public float[] embed(TokenDTO dto, String input) throws ZtatException, JsonProcessingException { + + var payload = Map.of("input", input, "model", "text-embedding-3-small"); + + var textResponse = zeroTrustClientService.callPostOnApi(dto, openAiEndpoint, "/embeddings/generate", payload); + + var response = JsonUtil.MAPPER.readTree(textResponse); + + var vector = response.get("embedding"); + float[] embedding = new float[vector.size()]; + for (int i = 0; i < vector.size(); i++) { + embedding[i] = (float) vector.get(i).asDouble(); + } + return embedding; + } } diff --git a/core/src/main/java/io/sentrius/sso/core/services/agents/ZeroTrustClientService.java b/core/src/main/java/io/sentrius/sso/core/services/agents/ZeroTrustClientService.java index 8eecbc61..d5abb74c 100644 --- a/core/src/main/java/io/sentrius/sso/core/services/agents/ZeroTrustClientService.java +++ b/core/src/main/java/io/sentrius/sso/core/services/agents/ZeroTrustClientService.java @@ -1,13 +1,17 @@ package io.sentrius.sso.core.services.agents; +import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import io.sentrius.sso.core.dto.PagedResultDTO; import io.sentrius.sso.core.dto.UserDTO; import io.sentrius.sso.core.dto.agents.AgentExecution; import io.sentrius.sso.core.dto.ztat.EndpointRequest; @@ -20,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.*; import org.springframework.stereotype.Service; import org.springframework.web.client.HttpClientErrorException; @@ -771,4 +776,87 @@ public boolean verifyZtatChallenge(AgentExecution execution, String ztatToken, S } } + public PagedResultDTO callPostOnApi( + @NonNull TokenDTO token, + @NonNull String apiEndpoint, + T body, + Class responseType, + Integer page, + Integer size, + List sort, + Map.Entry>... params + ) throws ZtatException { + + String keycloakJwt = getKeycloakToken(); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setBearerAuth(keycloakJwt); + headers.set("X-Ztat-Token", token.getZtatToken()); + headers.set("X-Communication-Id", token.getCommunicationId()); + + HttpEntity requestEntity = new HttpEntity<>(body, headers); + + if (!apiEndpoint.startsWith("/")) { + apiEndpoint = "/" + apiEndpoint; + } + if (!apiEndpoint.startsWith("/api/v1/")) { + apiEndpoint = "/api/v1" + apiEndpoint; + } + + var builder = UriComponentsBuilder.fromUri(URI.create(agentApiUrl)) + .path(apiEndpoint); + + // Add pagination params if present + if (page != null) builder.queryParam("page", page); + if (size != null) builder.queryParam("size", size); + if (sort != null) { + for (String s : sort) { + builder.queryParam("sort", s); + } + } + + // Add any other params + if (params != null) { + for (Map.Entry> entry : params) { + for (String value : entry.getValue()) { + builder.queryParam(entry.getKey(), UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8)); + } + } + } + + try { + // Call API as plain JSON String + ResponseEntity response = restTemplate.exchange( + builder.build(true).toUriString(), + HttpMethod.POST, + requestEntity, + String.class + ); + + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + + + // Construct JavaType for PagedResultDTO + JavaType type = JsonUtil.MAPPER.getTypeFactory() + .constructParametricType(PagedResultDTO.class, responseType); + + return JsonUtil.MAPPER.readValue(response.getBody(), type); + } else if (response.getStatusCode() == HttpStatus.PRECONDITION_REQUIRED) { + throw new ZtatException("Inaccessible endpoint: " + response.getStatusCode(), apiEndpoint); + } else { + throw new RuntimeException("Failed: " + response.getStatusCode()); + } + + } catch (HttpClientErrorException e) { + if (e.getStatusCode() == HttpStatus.PRECONDITION_REQUIRED) { + throw new ZtatException(e.getResponseBodyAsString(), apiEndpoint); + } else { + log.info("Error: {}", e.getResponseBodyAsString()); + } + throw new RuntimeException(e.getResponseBodyAsString()); + } catch (IOException e) { + throw new RuntimeException("Error deserializing response", e); + } } + +} diff --git a/core/src/main/java/io/sentrius/sso/core/services/capabilities/EndpointScanningService.java b/core/src/main/java/io/sentrius/sso/core/services/capabilities/EndpointScanningService.java index 266689e6..54820b83 100644 --- a/core/src/main/java/io/sentrius/sso/core/services/capabilities/EndpointScanningService.java +++ b/core/src/main/java/io/sentrius/sso/core/services/capabilities/EndpointScanningService.java @@ -6,6 +6,7 @@ import io.sentrius.sso.core.dto.capabilities.AccessLimitations; import io.sentrius.sso.core.dto.capabilities.EndpointDescriptor; import io.sentrius.sso.core.dto.capabilities.ParameterDescriptor; +import io.sentrius.sso.core.model.verbs.Endpoint; import io.sentrius.sso.core.model.verbs.Verb; import lombok.extern.slf4j.Slf4j; import org.springframework.context.ApplicationContext; @@ -132,6 +133,8 @@ private void scanRestControllerClass(Class clazz) { private EndpointDescriptor scanRestMethod(Class clazz, Method method, String basePath) { String httpMethod = null; String path = basePath; + String description = method.getAnnotation(Endpoint.class) != null ? + method.getAnnotation(Endpoint.class).description() : "REST endpoint: " + httpMethod + " " + path; // Check for HTTP method annotations if (method.isAnnotationPresent(GetMapping.class)) { @@ -180,7 +183,7 @@ private EndpointDescriptor scanRestMethod(Class clazz, Method method, String return EndpointDescriptor.builder() .name(method.getName()) - .description("REST endpoint: " + httpMethod + " " + path) + .description(description) .type("REST") .httpMethod(httpMethod) .path(path) diff --git a/core/src/main/java/io/sentrius/sso/core/services/endpoints/CosineSimilarity.java b/core/src/main/java/io/sentrius/sso/core/services/endpoints/CosineSimilarity.java new file mode 100644 index 00000000..c8fb78ed --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/services/endpoints/CosineSimilarity.java @@ -0,0 +1,14 @@ +package io.sentrius.sso.core.services.endpoints; + +public class CosineSimilarity { + + public static float score(float[] a, float[] b) { + float dot = 0f, normA = 0f, normB = 0f; + for (int i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + return (float) (dot / (Math.sqrt(normA) * Math.sqrt(normB))); + } +} diff --git a/core/src/main/java/io/sentrius/sso/core/services/security/KeycloakService.java b/core/src/main/java/io/sentrius/sso/core/services/security/KeycloakService.java index 4ab35d77..afbf2ffd 100644 --- a/core/src/main/java/io/sentrius/sso/core/services/security/KeycloakService.java +++ b/core/src/main/java/io/sentrius/sso/core/services/security/KeycloakService.java @@ -54,6 +54,21 @@ public Map> getUserAttributes(String userId) { return user.getAttributes(); } + public String extractInitialUserType(String token) { + if (token.startsWith("Bearer ")) { + token = token.substring(7); + } + var kid = JwtUtil.extractKid(token); + var publicKey = keycloak.getPublicKey(kid); + var claims = Jwts.parser() + .setSigningKey(publicKey) + .build() + .parseClaimsJws(token) + .getBody(); + + return claims.get("initial_user_type", String.class); + } + /** * Validate a JWT using the Keycloak Public Key. diff --git a/core/src/main/java/io/sentrius/sso/core/utils/ListUtils.java b/core/src/main/java/io/sentrius/sso/core/utils/ListUtils.java new file mode 100644 index 00000000..735dcfca --- /dev/null +++ b/core/src/main/java/io/sentrius/sso/core/utils/ListUtils.java @@ -0,0 +1,10 @@ +package io.sentrius.sso.core.utils; + +import java.util.List; + +public class ListUtils { + public static List getLastNElements(List list, int n) { + int size = list.size(); + return list.subList(Math.max(size - n, 0), size); + } +} \ No newline at end of file diff --git a/core/src/test/java/io/sentrius/sso/core/services/agents/EmbeddingServiceTest.java b/core/src/test/java/io/sentrius/sso/core/services/agents/EmbeddingServiceTest.java new file mode 100644 index 00000000..63f5a725 --- /dev/null +++ b/core/src/test/java/io/sentrius/sso/core/services/agents/EmbeddingServiceTest.java @@ -0,0 +1,268 @@ +package io.sentrius.sso.core.services.agents; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.web.client.RestTemplate; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class EmbeddingServiceTest { + + @Mock + private RestTemplate restTemplate; + + @Mock + private Authentication authentication; + + @Mock + private SecurityContext securityContext; + + private EmbeddingService embeddingService; + private static final String INTEGRATION_PROXY_URL = "http://localhost:8081"; + + @BeforeEach + void setUp() { + embeddingService = new EmbeddingService(restTemplate, INTEGRATION_PROXY_URL); + } + + @AfterEach + void tearDown() { + SecurityContextHolder.clearContext(); + } + + private void setupAuthenticationContext() { + lenient().when(securityContext.getAuthentication()).thenReturn(authentication); + lenient().when(authentication.getCredentials()).thenReturn("test-jwt-token"); + SecurityContextHolder.setContext(securityContext); + } + + @Test + void testIsAvailable_Success() { + // Arrange + setupAuthenticationContext(); + Map statusResponse = new HashMap<>(); + statusResponse.put("available", true); + ResponseEntity mockResponse = new ResponseEntity<>(statusResponse, HttpStatus.OK); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(Map.class))) + .thenReturn(mockResponse); + + // Act + boolean result = embeddingService.isAvailable(); + + // Assert + assertTrue(result); + verify(restTemplate).exchange(contains("/api/v1/embeddings/status"), eq(HttpMethod.GET), any(HttpEntity.class), eq(Map.class)); + } + + @Test + void testIsAvailable_NotAvailable() { + // Arrange + setupAuthenticationContext(); + Map statusResponse = new HashMap<>(); + statusResponse.put("available", false); + ResponseEntity mockResponse = new ResponseEntity<>(statusResponse, HttpStatus.OK); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(Map.class))) + .thenReturn(mockResponse); + + // Act + boolean result = embeddingService.isAvailable(); + + // Assert + assertFalse(result); + } + + @Test + void testIsAvailable_NoAuthentication() { + // Arrange - no authentication setup + + // Act + boolean result = embeddingService.isAvailable(); + + // Assert + assertFalse(result); + verify(restTemplate, never()).exchange(anyString(), any(), any(), any(Class.class)); + } + + @Test + void testEmbed_Success() { + // Arrange + setupAuthenticationContext(); + String inputText = "test text for embedding"; + float[] mockEmbedding = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + + Map responseBody = new HashMap<>(); + responseBody.put("embedding", mockEmbedding); + + ResponseEntity mockResponse = new ResponseEntity<>(responseBody, HttpStatus.OK); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.POST), any(HttpEntity.class), eq(Map.class))) + .thenReturn(mockResponse); + + // Act + float[] result = embeddingService.embed(inputText); + + // Assert + assertNotNull(result); + assertEquals(5, result.length); + assertEquals(0.1f, result[0], 0.001f); + assertEquals(0.2f, result[1], 0.001f); + assertEquals(0.3f, result[2], 0.001f); + assertEquals(0.4f, result[3], 0.001f); + assertEquals(0.5f, result[4], 0.001f); + + verify(restTemplate).exchange(contains("/api/v1/embeddings/generate"), eq(HttpMethod.POST), any(HttpEntity.class), eq(Map.class)); + } + + @Test + void testEmbed_NoAuthentication() { + // Arrange - no authentication setup + + // Act + float[] result = embeddingService.embed("test text"); + + // Assert + assertNull(result); + verify(restTemplate, never()).exchange(anyString(), any(), any(), any(Class.class)); + } + + @Test + void testEmbed_EmptyText() { + // Arrange + setupAuthenticationContext(); + + // Act + float[] result1 = embeddingService.embed(""); + float[] result2 = embeddingService.embed(null); + + // Assert + assertNull(result1); + assertNull(result2); + verify(restTemplate, never()).exchange(anyString(), any(), any(), any(Class.class)); + } + + @Test + void testEmbed_ApiError() { + // Arrange + setupAuthenticationContext(); + String inputText = "test text"; + + when(restTemplate.exchange(anyString(), eq(HttpMethod.POST), any(HttpEntity.class), eq(Map.class))) + .thenThrow(new RuntimeException("API Error")); + + // Act + float[] result = embeddingService.embed(inputText); + + // Assert + assertNull(result); + verify(restTemplate).exchange(contains("/api/v1/embeddings/generate"), eq(HttpMethod.POST), any(HttpEntity.class), eq(Map.class)); + } + + @Test + void testCalculateCosineSimilarity_ValidEmbeddings() { + // Arrange + float[] embedding1 = {1.0f, 0.0f, 0.0f}; + float[] embedding2 = {0.0f, 1.0f, 0.0f}; + float[] embedding3 = {1.0f, 0.0f, 0.0f}; // Same as embedding1 + + // Act & Assert + double similarity1 = EmbeddingService.calculateCosineSimilarity(embedding1, embedding2); + assertEquals(0.0, similarity1, 0.001); // Orthogonal vectors + + double similarity2 = EmbeddingService.calculateCosineSimilarity(embedding1, embedding3); + assertEquals(1.0, similarity2, 0.001); // Identical vectors + } + + @Test + void testCalculateCosineSimilarity_NullEmbeddings() { + // Arrange + float[] embedding1 = {1.0f, 0.0f, 0.0f}; + + // Act & Assert + double similarity1 = EmbeddingService.calculateCosineSimilarity(null, embedding1); + assertEquals(0.0, similarity1); + + double similarity2 = EmbeddingService.calculateCosineSimilarity(embedding1, null); + assertEquals(0.0, similarity2); + + double similarity3 = EmbeddingService.calculateCosineSimilarity(null, null); + assertEquals(0.0, similarity3); + } + + @Test + void testCalculateCosineSimilarity_DifferentLengths() { + // Arrange + float[] embedding1 = {1.0f, 0.0f, 0.0f}; + float[] embedding2 = {1.0f, 0.0f}; // Different length + + // Act + double similarity = EmbeddingService.calculateCosineSimilarity(embedding1, embedding2); + + // Assert + assertEquals(0.0, similarity); + } + + @Test + void testEmbedBatch_Success() { + // Arrange + setupAuthenticationContext(); + List texts = Arrays.asList("text1", "text2"); + Map mockEmbeddings = new HashMap<>(); + mockEmbeddings.put("text1", new float[]{0.1f, 0.2f, 0.3f}); + mockEmbeddings.put("text2", new float[]{0.4f, 0.5f, 0.6f}); + + Map responseBody = new HashMap<>(); + responseBody.put("embeddings", mockEmbeddings); + + ResponseEntity mockResponse = new ResponseEntity<>(responseBody, HttpStatus.OK); + + when(restTemplate.exchange(anyString(), eq(HttpMethod.POST), any(HttpEntity.class), eq(Map.class))) + .thenReturn(mockResponse); + + // Act + Map results = embeddingService.embedBatch(texts); + + // Assert + assertNotNull(results); + assertEquals(2, results.size()); + assertTrue(results.containsKey("text1")); + assertTrue(results.containsKey("text2")); + + verify(restTemplate).exchange(contains("/api/v1/embeddings/generate/batch"), eq(HttpMethod.POST), any(HttpEntity.class), eq(Map.class)); + } + + @Test + void testEmbedBatch_EmptyList() { + // Act + Map result1 = embeddingService.embedBatch(Arrays.asList()); + Map result2 = embeddingService.embedBatch(null); + + // Assert + assertNotNull(result1); + assertTrue(result1.isEmpty()); + assertNotNull(result2); + assertTrue(result2.isEmpty()); + + verify(restTemplate, never()).exchange(anyString(), any(), any(), any(Class.class)); + } +} \ No newline at end of file diff --git a/dataplane/pom.xml b/dataplane/pom.xml index 15f5ed2e..5f2c1062 100644 --- a/dataplane/pom.xml +++ b/dataplane/pom.xml @@ -30,6 +30,14 @@ provenance-core 1.0.0-SNAPSHOT + + org.hibernate.orm + hibernate-vector + + + org.apache.accumulo + accumulo-access + org.apache.commons commons-lang3 @@ -91,6 +99,12 @@ test + + org.springframework.boot + spring-boot-test-autoconfigure + test + + org.springframework.boot spring-boot-starter-oauth2-resource-server diff --git a/dataplane/src/main/java/io/sentrius/sso/core/config/SystemOptions.java b/dataplane/src/main/java/io/sentrius/sso/core/config/SystemOptions.java index be0dad54..9a2c28a1 100644 --- a/dataplane/src/main/java/io/sentrius/sso/core/config/SystemOptions.java +++ b/dataplane/src/main/java/io/sentrius/sso/core/config/SystemOptions.java @@ -122,6 +122,17 @@ public class SystemOptions { @Updatable(description = "Allows LLM to ask questions of the user") public Boolean enableLLMQuestions = false; + @Updatable(description = "Enables agent memory store functionality") + @Builder.Default public Boolean enableMemoryStore = true; + + @Updatable(description = "Enables vector store capabilities for semantic memory search") + @Builder.Default public Boolean enableVectorStore = true; + + @Updatable(description = "Default similarity threshold for vector searches") + @Builder.Default public Double vectorSimilarityThreshold = 0.7; + + @Updatable(description = "Dimension size for vector embeddings") + @Builder.Default public Integer vectorDimension = 1536; public Boolean lockdownEnabled = false; @@ -171,6 +182,8 @@ private void init() throws IllegalAccessException { field.set(this, Boolean.parseBoolean(propertyValue)); } else if (field.getType() == Integer.class || field.getType() == int.class) { field.set(this, Integer.parseInt(propertyValue)); + } else if (field.getType() == Double.class || field.getType() == double.class) { + field.set(this, Double.parseDouble(propertyValue)); } else if (field.getType() == String.class) { field.set(this, propertyValue); } diff --git a/dataplane/src/main/java/io/sentrius/sso/core/integrations/ticketing/JiraService.java b/dataplane/src/main/java/io/sentrius/sso/core/integrations/ticketing/JiraService.java index 0a047682..ecdc8f41 100644 --- a/dataplane/src/main/java/io/sentrius/sso/core/integrations/ticketing/JiraService.java +++ b/dataplane/src/main/java/io/sentrius/sso/core/integrations/ticketing/JiraService.java @@ -247,4 +247,32 @@ public String extractTextFromADF(Object adf) { return ""; } } + + public List getComments(String ticketKey) { + List comments = new ArrayList<>(); + try { + String commentsUrl = String.format("%s/rest/api/3/issue/%s/comment", jiraBaseUrl, ticketKey); + + HttpHeaders headers = new HttpHeaders(); + headers.setBasicAuth(username, apiToken); + headers.setContentType(MediaType.APPLICATION_JSON); + + HttpEntity requestEntity = new HttpEntity<>(headers); + ResponseEntity response = restTemplate.exchange(commentsUrl, HttpMethod.GET, requestEntity, String.class); + + if (response.getStatusCode() == HttpStatus.OK) { + JsonNode root = JsonUtil.MAPPER.readTree(response.getBody()); + ArrayNode commentNodes = (ArrayNode) root.path("comments"); + for (JsonNode commentNode : commentNodes) { + String commentText = extractTextFromADF(commentNode.path("body")); + comments.add(commentText); + } + } else { + log.info("Failed to fetch comments. Status: {}", response.getStatusCode()); + } + } catch (Exception e) { + log.error("Error fetching comments: {}", e.getMessage()); + } + return comments; + } } diff --git a/dataplane/src/main/java/io/sentrius/sso/core/model/agents/AgentMemory.java b/dataplane/src/main/java/io/sentrius/sso/core/model/agents/AgentMemory.java new file mode 100644 index 00000000..71b988f0 --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/model/agents/AgentMemory.java @@ -0,0 +1,228 @@ +package io.sentrius.sso.core.model.agents; + +import java.time.Instant; +import jakarta.persistence.*; +import lombok.Getter; +import lombok.Setter; +import lombok.Builder; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.type.SqlTypes; +import java.util.Map; +import java.util.HashMap; +import java.util.Arrays; + +@Entity +@Table(name = "agent_memory") +@Getter +@Setter +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class AgentMemory { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + private Long id; + + @Column(name = "memory_key", nullable = false) + private String memoryKey; + + @Column(name = "memory_value", nullable = false, columnDefinition = "TEXT") + private String memoryValue; + + @Column(name = "memory_type") + private String memoryType = "JSON"; + + @Column(name = "agent_id") + private String agentId; + + @Column(name = "agent_name") + private String agentName; + + @Column(name = "conversation_id") + private String conversationId; + + @Column(name = "classification") + private String classification = "PRIVATE"; + + @Column(name = "markings") + private String markings; + + @Column(name = "access_level") + private String accessLevel = "AGENT_ONLY"; + + @Column(name = "creator_user_id") + private String creatorUserId; + + @Column(name = "creator_user_type") + private String creatorUserType; + + @Column(name = "created_at") + private Instant createdAt; + + @Column(name = "updated_at") + private Instant updatedAt; + + @Column(name = "expires_at") + private Instant expiresAt; + + @Column(name = "shared_with_agents", columnDefinition = "TEXT") + private String sharedWithAgents; + + @Column(name = "metadata", columnDefinition = "jsonb") + @JdbcTypeCode(SqlTypes.JSON) + private JsonNode metadata; + + @Column(name = "version") + @Builder.Default + private Integer version = 1; + + @Column(name = "embedding", columnDefinition = "vector(1536)") + @JdbcTypeCode(SqlTypes.VECTOR) + private float[] embedding; + + @PrePersist + protected void onCreate() { + createdAt = updatedAt = Instant.now(); + } + + @PreUpdate + protected void onUpdate() { + updatedAt = Instant.now(); + version++; + } + + // Enum for predefined classifications + public enum Classification { + PUBLIC, PRIVATE, SHARED, CONFIDENTIAL + } + + // Enum for predefined access levels + public enum AccessLevel { + ALL_USERS, AGENT_ONLY, TEAM_MEMBERS, CREATOR_ONLY, ADMIN_ONLY + } + + // Helper methods for metadata + public Map getMetadataAsMap() { + if (metadata == null || metadata.isNull()) { + return new HashMap<>(); + } + try { + ObjectMapper mapper = new ObjectMapper(); + return mapper.convertValue(metadata, Map.class); + } catch (IllegalArgumentException e) { + return new HashMap<>(); + } + } + + public void setMetadataFromMap(Map metadataMap) { + if (metadataMap == null || metadataMap.isEmpty()) { + this.metadata = null; + return; + } + ObjectMapper mapper = new ObjectMapper(); + this.metadata = mapper.valueToTree(metadataMap); + } + + public JsonNode getMetadataAsJsonNode() { + return metadata; // already a JsonNode, no parsing needed + } + + public void setMetadataAsJsonNode(JsonNode jsonNode) { + this.metadata = jsonNode; + } + + // Helper methods for markings + public String[] getMarkingsArray() { + return markings != null ? markings.split(",") : new String[0]; + } + + public void setMarkingsArray(String[] markingsArray) { + this.markings = markingsArray != null ? String.join(",", markingsArray) : null; + } + + // Helper methods for shared agents + public String[] getSharedAgentsArray() { + return sharedWithAgents != null ? sharedWithAgents.split(",") : new String[0]; + } + + public void setSharedAgentsArray(String[] sharedAgentsArray) { + this.sharedWithAgents = sharedAgentsArray != null ? String.join(",", sharedAgentsArray) : null; + } + + // Helper methods for validation + public boolean isExpired() { + return expiresAt != null && Instant.now().isAfter(expiresAt); + } + + public boolean hasMarking(String marking) { + if (markings == null) return false; + String[] markingArray = getMarkingsArray(); + for (String m : markingArray) { + if (m.trim().equalsIgnoreCase(marking.trim())) { + return true; + } + } + return false; + } + + public boolean canBeSharedWith(String agentId) { + if (accessLevel != null && accessLevel.equals("ALL_USERS")) return true; + if (sharedWithAgents == null) return false; + String[] sharedAgents = getSharedAgentsArray(); + for (String shared : sharedAgents) { + if (shared.trim().equals(agentId.trim())) { + return true; + } + } + return false; + } + + // Helper methods for embeddings + public boolean hasEmbedding() { + return embedding != null && embedding.length > 0; + } + + public void setEmbedding(float[] embedding) { + this.embedding = embedding; + } + + public float[] getEmbedding() { + return embedding; + } + + public String getEmbeddingAsString() { + return embedding != null ? Arrays.toString(embedding) : null; + } + + // Calculate cosine similarity between this memory's embedding and another + public double calculateCosineSimilarity(float[] otherEmbedding) { + if (embedding == null || otherEmbedding == null || + embedding.length != otherEmbedding.length) { + return 0.0; + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < embedding.length; i++) { + dotProduct += embedding[i] * otherEmbedding[i]; + normA += Math.pow(embedding[i], 2); + normB += Math.pow(otherEmbedding[i], 2); + } + + if (normA == 0.0 || normB == 0.0) { + return 0.0; + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/model/agents/MemoryAccessPolicy.java b/dataplane/src/main/java/io/sentrius/sso/core/model/agents/MemoryAccessPolicy.java new file mode 100644 index 00000000..92f1b797 --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/model/agents/MemoryAccessPolicy.java @@ -0,0 +1,238 @@ +package io.sentrius.sso.core.model.agents; + +import java.time.Instant; +import java.util.Map; +import java.util.HashMap; +import com.fasterxml.jackson.databind.JsonNode; +import jakarta.persistence.*; +import lombok.Getter; +import lombok.Setter; +import lombok.Builder; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.type.SqlTypes; + +/** + * Memory Access Policy defines fine-grained access control for agent memory operations. + * + * This policy system provides a secondary layer of policy enforcement that operates + * AFTER trust policies have been evaluated. Key considerations: + * + * - Trust policies are the primary enforcement layer and can completely preclude memory usage + * - If a trust policy blocks memory access, these memory access policies will never be evaluated + * - These policies only take effect when trust policies allow memory operations to proceed + * - Trust policy decisions override memory access policy decisions + * + * Policy Evaluation Order: + * 1. Trust policies are evaluated first (can completely block memory access) + * 2. If trust policies allow, then memory access policies are evaluated + * 3. Both must allow access for the operation to proceed + * + * This design ensures that high-level organizational trust decisions take precedence + * over specific memory access rules. + */ +@Entity +@Table(name = "memory_access_policies") +@Getter +@Setter +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class MemoryAccessPolicy { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + private Long id; + + @Column(name = "policy_name", nullable = false, unique = true) + private String policyName; + + @Column(name = "policy_description", columnDefinition = "TEXT") + private String policyDescription; + + @Column(name = "target_classification") + private String targetClassification; + + @Column(name = "target_markings") + private String targetMarkings; + + + @Column(name = "required_user_attributes", columnDefinition = "jsonb") + @JdbcTypeCode(SqlTypes.JSON) + private JsonNode requiredUserAttributes; + + @Column(name = "required_agent_attributes", columnDefinition = "jsonb") + @JdbcTypeCode(SqlTypes.JSON) + private JsonNode requiredAgentAttributes; + + @Column(name = "access_type") + private String accessType = "READ"; + + @Column(name = "is_active") + private Boolean isActive = true; + + @Column(name = "created_at") + private Instant createdAt; + + @Column(name = "updated_at") + private Instant updatedAt; + + @PrePersist + protected void onCreate() { + createdAt = updatedAt = Instant.now(); + } + + @PreUpdate + protected void onUpdate() { + updatedAt = Instant.now(); + } + + // Enum for predefined access types + public enum AccessType { + READ, WRITE, DELETE, FULL + } + + public Map getRequiredUserAttributesAsMap() { + if (requiredUserAttributes == null || requiredUserAttributes.isEmpty()) { + return new HashMap<>(); + } + try { + ObjectMapper mapper = new ObjectMapper(); + return mapper.convertValue(requiredUserAttributes, Map.class); + } catch (IllegalArgumentException e) { + return new HashMap<>(); + } + } + + public void setRequiredUserAttributesFromMap(Map attributesMap) { + if (attributesMap == null || attributesMap.isEmpty()) { + this.requiredUserAttributes = null; + return; + } + ObjectMapper mapper = new ObjectMapper(); + this.requiredUserAttributes = mapper.valueToTree(attributesMap); + } + + public Map getRequiredAgentAttributesAsMap() { + if (requiredAgentAttributes == null || requiredAgentAttributes.isEmpty()) { + return new HashMap<>(); + } + try { + ObjectMapper mapper = new ObjectMapper(); + return mapper.convertValue(requiredAgentAttributes, Map.class); + } catch (IllegalArgumentException e) { + return new HashMap<>(); + } + } + + public void setRequiredAgentAttributesFromMap(Map attributesMap) { + if (attributesMap == null || attributesMap.isEmpty()) { + this.requiredAgentAttributes = null; + return; + } + ObjectMapper mapper = new ObjectMapper(); + this.requiredAgentAttributes = mapper.valueToTree(attributesMap); + } + + + // Helper methods for markings + public String[] getTargetMarkingsArray() { + return targetMarkings != null ? targetMarkings.split(",") : new String[0]; + } + + public void setTargetMarkingsArray(String[] markingsArray) { + this.targetMarkings = markingsArray != null ? String.join(",", markingsArray) : null; + } + + // Helper methods for validation + public boolean appliesToClassification(String classification) { + return targetClassification == null || targetClassification.equals(classification); + } + + public boolean appliesToMarkings(String markings) { + if (targetMarkings == null) return true; + if (markings == null) return false; + + String[] targetArray = getTargetMarkingsArray(); + String[] memoryMarkings = markings.split(","); + + // Check if any target marking is present in memory markings + for (String target : targetArray) { + for (String memory : memoryMarkings) { + if (target.trim().equalsIgnoreCase(memory.trim())) { + return true; + } + } + } + return false; + } + + public boolean allowsAccessType(String requestedAccessType) { + if (accessType == null || accessType.equals("FULL")) return true; + return accessType.equalsIgnoreCase(requestedAccessType); + } + + // ABAC evaluation methods + public boolean evaluateUserAttributes(Map userAttributes) { + Map required = getRequiredUserAttributesAsMap(); + if (required.isEmpty()) { + return true; + } + + for (Map.Entry requiredEntry : required.entrySet()) { + String requiredKey = requiredEntry.getKey(); + Object requiredValue = requiredEntry.getValue(); + + if (!userAttributes.containsKey(requiredKey)) { + return false; + } + + Object userValue = userAttributes.get(requiredKey); + if (!matchesValue(requiredValue, userValue)) { + return false; + } + } + + return true; + } + + public boolean evaluateAgentAttributes(Map agentAttributes) { + Map required = getRequiredAgentAttributesAsMap(); + if (required.isEmpty()) { + return true; + } + + for (Map.Entry requiredEntry : required.entrySet()) { + String requiredKey = requiredEntry.getKey(); + Object requiredValue = requiredEntry.getValue(); + + if (!agentAttributes.containsKey(requiredKey)) { + return false; + } + + Object agentValue = agentAttributes.get(requiredKey); + if (!matchesValue(requiredValue, agentValue)) { + return false; + } + } + + return true; + } + + private boolean matchesValue(Object required, Object actual) { + if (required == null && actual == null) return true; + if (required == null || actual == null) return false; + + // Special case for "user_id" attribute - match against creator + if (required.equals("user_id")) { + return true; // This will be evaluated at runtime + } + + return required.toString().equals(actual.toString()); + } +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/model/users/UserAttribute.java b/dataplane/src/main/java/io/sentrius/sso/core/model/users/UserAttribute.java new file mode 100644 index 00000000..d4839476 --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/model/users/UserAttribute.java @@ -0,0 +1,137 @@ +package io.sentrius.sso.core.model.users; + +import java.time.Instant; +import jakarta.persistence.*; +import lombok.Getter; +import lombok.Setter; +import lombok.Builder; +import lombok.NoArgsConstructor; +import lombok.AllArgsConstructor; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +@Entity +@Table(name = "user_attributes") +@Getter +@Setter +@Builder +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class UserAttribute { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + private Long id; + + @Column(name = "user_id", nullable = false) + private String userId; + + @Column(name = "attribute_name", nullable = false) + private String attributeName; + + @Column(name = "attribute_value", nullable = false, columnDefinition = "TEXT") + private String attributeValue; + + @Column(name = "attribute_type") + private String attributeType = "STRING"; + + @Column(name = "source") + private String source = "SENTRIUS"; + + @Column(name = "is_active") + private Boolean isActive = true; + + @Column(name = "created_at") + private Instant createdAt; + + @Column(name = "updated_at") + private Instant updatedAt; + + /** + * Indicates whether this attribute was synchronized from Keycloak. + * + * Sync States: + * - true: Attribute was imported from Keycloak and should be treated as externally managed + * - false: Attribute was created locally in Sentrius + * + * Risks of not being synced: + * - Data inconsistency between Keycloak and Sentrius user profiles + * - Potential security policy mismatches if attributes are used for access control + * - Loss of centralized identity management benefits + * - Manual attribute updates may be overwritten during next sync + * - Audit trail gaps when tracking attribute source changes + */ + @Column(name = "synced_from_keycloak") + private Boolean syncedFromKeycloak = false; + + @PrePersist + protected void onCreate() { + createdAt = updatedAt = Instant.now(); + } + + @PreUpdate + protected void onUpdate() { + updatedAt = Instant.now(); + } + + // Enum for predefined attribute types + public enum AttributeType { + STRING, INTEGER, BOOLEAN, JSON, LIST, DATE + } + + // Enum for predefined sources + public enum Source { + SENTRIUS, KEYCLOAK, LDAP, EXTERNAL + } + + // Helper methods for type-safe value access + public String getStringValue() { + return attributeValue; + } + + public Integer getIntegerValue() { + try { + return Integer.parseInt(attributeValue); + } catch (NumberFormatException e) { + return null; + } + } + + public Boolean getBooleanValue() { + return Boolean.parseBoolean(attributeValue); + } + + public String[] getListValue() { + return attributeValue != null ? attributeValue.split(",") : new String[0]; + } + + // Helper methods for validation + public boolean isValidForType() { + if (attributeType == null) return true; + + switch (attributeType.toUpperCase()) { + case "INTEGER": + try { + Integer.parseInt(attributeValue); + return true; + } catch (NumberFormatException e) { + return false; + } + case "BOOLEAN": + return "true".equalsIgnoreCase(attributeValue) || "false".equalsIgnoreCase(attributeValue); + case "JSON": + // Basic JSON validation - starts with { or [ + return attributeValue.trim().startsWith("{") || attributeValue.trim().startsWith("["); + default: + return true; + } + } + + public boolean matches(String value) { + return attributeValue != null && attributeValue.equals(value); + } + + public boolean contains(String value) { + return attributeValue != null && attributeValue.contains(value); + } +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/repository/AgentMemoryRepository.java b/dataplane/src/main/java/io/sentrius/sso/core/repository/AgentMemoryRepository.java new file mode 100644 index 00000000..46b43501 --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/repository/AgentMemoryRepository.java @@ -0,0 +1,194 @@ +package io.sentrius.sso.core.repository; + +import io.sentrius.sso.core.model.agents.AgentMemory; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; +import org.springframework.stereotype.Repository; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +@Repository +public interface AgentMemoryRepository extends JpaRepository { + + // === Basic Finders === + + List findByAgentIdOrderByCreatedAtDesc(String agentId); + + List findByConversationIdOrderByCreatedAtDesc(String conversationId); + + Optional findByMemoryKey(String memoryKey); + + Optional findByAgentIdAndMemoryKey(String agentId, String memoryKey); + + List findByClassificationOrderByCreatedAtDesc(String classification); + + List findByAccessLevelOrderByCreatedAtDesc(String accessLevel); + + List findByCreatorUserIdOrderByCreatedAtDesc(String creatorUserId); + + // === Sharing === + + @Query(""" + SELECT m FROM AgentMemory m + WHERE (m.accessLevel = 'ALL_USERS' OR m.agentId = :agentId OR m.sharedWithAgents LIKE %:agentId%) + AND (m.expiresAt IS NULL OR m.expiresAt > :now) + """) + List findShareableMemories(@Param("agentId") String agentId, @Param("now") Instant now); + + // === Markings === + + @Query("SELECT m FROM AgentMemory m WHERE m.markings LIKE %:marking%") + List findByMarkingsContaining(@Param("marking") String marking); + + // === JPQL filterable query === + + @Query(""" + SELECT m FROM AgentMemory m + WHERE (:agentId IS NULL OR m.agentId = :agentId) + AND (:classification IS NULL OR m.classification = :classification) + AND (:markings IS NULL OR m.markings LIKE CONCAT('%', :markings, '%')) + AND (m.expiresAt IS NULL OR m.expiresAt > :now) + """) + Page findMemoriesWithFilters( + @Param("agentId") String agentId, + @Param("classification") String classification, + @Param("markings") String markings, + @Param("now") Instant now, + Pageable pageable); + + // === Native filterable query (explicit casting for Postgres) === + + @Query( + value = """ + SELECT * + FROM agent_memory m + WHERE (:agentId IS NULL OR m.agent_id = :agentId) + AND (:classification IS NULL OR m.classification = :classification) + AND (:markings IS NULL OR m.markings LIKE CONCAT('%', CAST(:markings AS VARCHAR), '%')) + AND (m.expires_at IS NULL OR m.expires_at > :now) + ORDER BY m.created_at DESC + """, + countQuery = """ + SELECT COUNT(*) + FROM agent_memory m + WHERE (:agentId IS NULL OR m.agent_id = :agentId) + AND (:classification IS NULL OR m.classification = :classification) + AND (:markings IS NULL OR m.markings LIKE CONCAT('%', CAST(:markings AS VARCHAR), '%')) + AND (m.expires_at IS NULL OR m.expires_at > :now) + """, + nativeQuery = true + ) + Page findMemoriesWithFiltersNative( + @Param("agentId") String agentId, + @Param("classification") String classification, + @Param("markings") String markings, + @Param("now") Instant now, + Pageable pageable); + + // === Expiration === + + @Query("SELECT m FROM AgentMemory m WHERE m.expiresAt IS NULL OR m.expiresAt > :now") + List findNonExpiredMemories(@Param("now") Instant now); + + @Query("SELECT m FROM AgentMemory m WHERE m.expiresAt IS NOT NULL AND m.expiresAt <= :now") + List findExpiredMemories(@Param("now") Instant now); + + void deleteByExpiresAtLessThanEqual(Instant expiredBefore); + + // === Search === + + @Query("SELECT m FROM AgentMemory m WHERE LOWER(m.memoryValue) LIKE LOWER(CONCAT('%', :searchTerm, '%'))") + List searchByMemoryValue(@Param("searchTerm") String searchTerm); + + // === Counts === + + long countByAgentId(String agentId); + + long countByClassification(String classification); + + @Query( + value = "SELECT COUNT(*) FROM agent_memory WHERE embedding IS NOT NULL", + nativeQuery = true + ) + long countMemoriesWithEmbeddings(); + + + // === Embeddings === + + @Query( + value = "SELECT m FROM AgentMemory m WHERE m.embedding IS NULL ORDER BY m.createdAt DESC", + nativeQuery = true) + List findMemoriesWithoutEmbeddings(Pageable pageable); + + // === Vector similarity searches === + + @Query(value = """ + SELECT * FROM agent_memory + WHERE embedding IS NOT NULL + ORDER BY embedding <=> CAST(:queryEmbedding AS vector) + LIMIT :limit + """, nativeQuery = true) + List findSimilarMemories(@Param("queryEmbedding") String queryEmbedding, + @Param("limit") int limit); + + @Query(value = """ + SELECT * FROM agent_memory + WHERE embedding IS NOT NULL + AND classification = :classification + ORDER BY embedding <=> CAST(:queryEmbedding AS vector) + LIMIT :limit + """, nativeQuery = true) + List findSimilarMemoriesByClassification(@Param("queryEmbedding") String queryEmbedding, + @Param("classification") String classification, + @Param("limit") int limit); + + @Query(value = """ + SELECT * FROM agent_memory + WHERE embedding IS NOT NULL + AND markings LIKE CONCAT('%', CAST(:markings AS VARCHAR), '%') + ORDER BY embedding <=> CAST(:queryEmbedding AS vector) + LIMIT :limit + """, nativeQuery = true) + List findSimilarMemoriesByMarkings(@Param("queryEmbedding") String queryEmbedding, + @Param("markings") String markings, + @Param("limit") int limit); + + @Query(value = """ + SELECT *, (embedding <=> CAST(:queryEmbedding AS vector)) AS distance + FROM agent_memory + WHERE embedding IS NOT NULL + AND (agent_id = :agentId OR access_level = 'ALL_USERS' OR shared_with_agents LIKE CONCAT('%', CAST(:agentId AS VARCHAR), '%')) + AND (embedding <=> CAST(:queryEmbedding AS vector)) < :threshold + ORDER BY distance + LIMIT :limit + """, nativeQuery = true) + List findSimilarMemoriesForAgent(@Param("queryEmbedding") String queryEmbedding, + @Param("agentId") String agentId, + @Param("threshold") double threshold, + @Param("limit") int limit); + + @Query(value = """ + SELECT *, (embedding <=> CAST(:queryEmbedding AS vector)) AS distance + FROM agent_memory + WHERE embedding IS NOT NULL + AND ( + LOWER(memory_value) LIKE LOWER(CONCAT('%', :searchTerm, '%')) + OR markings LIKE CONCAT('%', CAST(:searchTerm AS VARCHAR), '%') + OR (embedding <=> CAST(:queryEmbedding AS vector)) < :threshold + ) + ORDER BY + CASE WHEN LOWER(memory_value) LIKE LOWER(CONCAT('%', :searchTerm, '%')) THEN 0 ELSE 1 END, + distance + LIMIT :limit + """, nativeQuery = true) + List hybridSearch(@Param("searchTerm") String searchTerm, + @Param("queryEmbedding") String queryEmbedding, + @Param("threshold") double threshold, + @Param("limit") int limit); +} + diff --git a/dataplane/src/main/java/io/sentrius/sso/core/repository/MemoryAccessPolicyRepository.java b/dataplane/src/main/java/io/sentrius/sso/core/repository/MemoryAccessPolicyRepository.java new file mode 100644 index 00000000..86ad4c9c --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/repository/MemoryAccessPolicyRepository.java @@ -0,0 +1,54 @@ +package io.sentrius.sso.core.repository; + +import io.sentrius.sso.core.model.agents.MemoryAccessPolicy; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; +import org.springframework.stereotype.Repository; + +import java.util.List; +import java.util.Optional; + +@Repository +public interface MemoryAccessPolicyRepository extends JpaRepository { + + // Find by policy name + Optional findByPolicyNameAndIsActiveTrue(String policyName); + + // Find all active policies + List findByIsActiveTrueOrderByPolicyName(); + + // Find policies by classification + List findByTargetClassificationAndIsActiveTrue(String classification); + + // Find policies by access type + List findByAccessTypeAndIsActiveTrue(String accessType); + + // Find policies that apply to specific markings + @Query("SELECT p FROM MemoryAccessPolicy p WHERE " + + "p.isActive = true AND " + + "(p.targetMarkings IS NULL OR p.targetMarkings LIKE %:marking%)") + List findPoliciesForMarkings(@Param("marking") String marking); + + // Find policies that apply to specific classification and markings + @Query("SELECT p FROM MemoryAccessPolicy p WHERE " + + "p.isActive = true AND " + + "(p.targetClassification IS NULL OR p.targetClassification = :classification) AND " + + "(p.targetMarkings IS NULL OR p.targetMarkings LIKE %:markings%)") + List findApplicablePolicies(@Param("classification") String classification, + @Param("markings") String markings); + + // Find policies by access type and classification + @Query("SELECT p FROM MemoryAccessPolicy p WHERE " + + "p.isActive = true AND " + + "p.accessType = :accessType AND " + + "(p.targetClassification IS NULL OR p.targetClassification = :classification)") + List findPoliciesForAccess(@Param("accessType") String accessType, + @Param("classification") String classification); + + // Count active policies + long countByIsActiveTrue(); + + // Count policies by classification + long countByTargetClassificationAndIsActiveTrue(String classification); +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/repository/UserAttributeRepository.java b/dataplane/src/main/java/io/sentrius/sso/core/repository/UserAttributeRepository.java new file mode 100644 index 00000000..2488c02e --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/repository/UserAttributeRepository.java @@ -0,0 +1,66 @@ +package io.sentrius.sso.core.repository; + +import io.sentrius.sso.core.model.users.UserAttribute; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; +import org.springframework.stereotype.Repository; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +@Repository +public interface UserAttributeRepository extends JpaRepository { + + // Find by user ID + List findByUserIdAndIsActiveTrue(String userId); + + // Find by user ID and attribute name + Optional findByUserIdAndAttributeNameAndIsActiveTrue(String userId, String attributeName); + + // Find by attribute name across all users + List findByAttributeNameAndIsActiveTrue(String attributeName); + + // Find by attribute value + List findByAttributeValueAndIsActiveTrue(String attributeValue); + + // Find by source + List findBySourceAndIsActiveTrue(String source); + + // Find Keycloak synced attributes + List findBySyncedFromKeycloakTrueAndIsActiveTrue(); + + // Find by user and source + List findByUserIdAndSourceAndIsActiveTrue(String userId, String source); + + // Search attributes by name pattern + @Query("SELECT ua FROM UserAttribute ua WHERE ua.attributeName LIKE %:namePattern% AND ua.isActive = true") + List findByAttributeNameContaining(@Param("namePattern") String namePattern); + + // Get user attributes as map + @Query("SELECT NEW map(ua.attributeName as name, ua.attributeValue as value) " + + "FROM UserAttribute ua WHERE ua.userId = :userId AND ua.isActive = true") + List> getUserAttributesAsMap(@Param("userId") String userId); + + // Check if user has specific attribute value + @Query("SELECT COUNT(ua) > 0 FROM UserAttribute ua WHERE " + + "ua.userId = :userId AND ua.attributeName = :name AND ua.attributeValue = :value AND ua.isActive = true") + boolean userHasAttributeValue(@Param("userId") String userId, + @Param("name") String attributeName, + @Param("value") String attributeValue); + + // Find users with specific attribute + @Query("SELECT DISTINCT ua.userId FROM UserAttribute ua WHERE " + + "ua.attributeName = :name AND ua.attributeValue = :value AND ua.isActive = true") + List findUserIdsWithAttribute(@Param("name") String attributeName, @Param("value") String attributeValue); + + // Count attributes by user + long countByUserIdAndIsActiveTrue(String userId); + + // Count attributes by name + long countByAttributeNameAndIsActiveTrue(String attributeName); + + // Delete by user and attribute name + void deleteByUserIdAndAttributeName(String userId, String attributeName); +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/services/agents/MemoryAccessControlService.java b/dataplane/src/main/java/io/sentrius/sso/core/services/agents/MemoryAccessControlService.java new file mode 100644 index 00000000..be81452d --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/services/agents/MemoryAccessControlService.java @@ -0,0 +1,290 @@ +package io.sentrius.sso.core.services.agents; + +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.model.agents.MemoryAccessPolicy; +import io.sentrius.sso.core.model.users.UserAttribute; +import io.sentrius.sso.core.repository.MemoryAccessPolicyRepository; +import io.sentrius.sso.core.repository.UserAttributeRepository; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; +import org.springframework.stereotype.Service; + +import java.util.*; +import java.util.stream.Collectors; + +@Slf4j +@Service +@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) +public class MemoryAccessControlService { + + private final MemoryAccessPolicyRepository policyRepository; + private final UserAttributeRepository userAttributeRepository; + + public MemoryAccessControlService( + MemoryAccessPolicyRepository policyRepository, + UserAttributeRepository userAttributeRepository) { + this.policyRepository = policyRepository; + this.userAttributeRepository = userAttributeRepository; + } + + /** + * Main ABAC evaluation method - determines if a user can access a memory item + */ + public boolean canAccessMemory(AgentMemory memory, String userId, String agentId, String accessType) { + log.debug("Evaluating access for user: {}, agent: {}, memory: {}, access: {}", + userId, agentId, memory.getMemoryKey(), accessType); + + // Quick checks for obvious cases + if (memory.isExpired()) { + log.debug("Memory expired, denying access"); + return false; + } + + // If memory is public and access type is READ, allow + if ("PUBLIC".equals(memory.getClassification()) && "READ".equals(accessType)) { + log.debug("Public memory read access granted"); + return true; + } + + // If user is the creator, allow all access types + if (userId != null && userId.equals(memory.getCreatorUserId())) { + log.debug("Creator access granted"); + return true; + } + + // If agent is accessing its own memory, allow based on access level + /* + if (agentId != null && agentId.equals(memory.getAgentId())) { + return evaluateAgentAccess(memory, accessType); + }*/ + + // Check if memory can be shared with the agent + if (agentId != null && memory.canBeSharedWith(agentId)) { + return evaluateSharedAccess(memory, userId, accessType); + } + + // Get user attributes for ABAC evaluation + Map userAttributes = getUserAttributesMap(userId); + + // Get agent attributes (if available) + Map agentAttributes = getAgentAttributesMap(agentId); + + // Find applicable policies + List applicablePolicies = findApplicablePolicies( + memory.getClassification(), memory.getMarkings(), accessType); + + // Evaluate policies + for (MemoryAccessPolicy policy : applicablePolicies) { + if (evaluatePolicy(policy, userAttributes, agentAttributes, memory, userId)) { + log.debug("Access granted by policy: {}", policy.getPolicyName()); + return true; + } + } + + log.debug("Access denied - no applicable policies matched"); + return false; + } + + /** + * Evaluate agent access based on access level + */ + private boolean evaluateAgentAccess(AgentMemory memory, String accessType) { + String accessLevel = memory.getAccessLevel(); + + if ("ALL_USERS".equals(accessLevel)) { + return true; + } + + if ("AGENT_ONLY".equals(accessLevel)) { + return !"DELETE".equals(accessType); // Agent can read/write but not delete its own memory + } + + return false; + } + + /** + * Evaluate shared access + */ + private boolean evaluateSharedAccess(AgentMemory memory, String userId, String accessType) { + // For shared memories, typically allow read access, restrict write/delete + if ("READ".equals(accessType)) { + return true; + } + + // Check if user is creator for write/delete + return userId != null && userId.equals(memory.getCreatorUserId()); + } + + /** + * Find applicable policies for memory access + */ + private List findApplicablePolicies(String classification, String markings, String accessType) { + List allPolicies = policyRepository.findByIsActiveTrueOrderByPolicyName(); + + return allPolicies.stream() + .filter(policy -> policy.appliesToClassification(classification)) + .filter(policy -> policy.appliesToMarkings(markings)) + .filter(policy -> policy.allowsAccessType(accessType)) + .collect(Collectors.toList()); + } + + /** + * Evaluate a specific policy + */ + private boolean evaluatePolicy(MemoryAccessPolicy policy, Map userAttributes, + Map agentAttributes, AgentMemory memory, String userId) { + log.debug("Evaluating policy: {}", policy.getPolicyName()); + + // Special handling for user_id in required attributes + Map evaluatedUserAttributes = new HashMap<>(userAttributes); + if (userId != null) { + evaluatedUserAttributes.put("user_id", userId); + evaluatedUserAttributes.put("created_by", memory.getCreatorUserId()); + } + + // Evaluate user attributes + if (!policy.evaluateUserAttributes(evaluatedUserAttributes)) { + log.debug("Policy {} failed user attribute evaluation", policy.getPolicyName()); + return false; + } + + // Evaluate agent attributes + if (!policy.evaluateAgentAttributes(agentAttributes)) { + log.debug("Policy {} failed agent attribute evaluation", policy.getPolicyName()); + return false; + } + + log.debug("Policy {} passed all evaluations", policy.getPolicyName()); + return true; + } + + /** + * Get user attributes as a map + */ + private Map getUserAttributesMap(String userId) { + if (userId == null) { + return new HashMap<>(); + } + + List attributes = userAttributeRepository.findByUserIdAndIsActiveTrue(userId); + Map attributeMap = new HashMap<>(); + + for (UserAttribute attr : attributes) { + attributeMap.put(attr.getAttributeName(), attr.getAttributeValue()); + } + + // Add default attributes + attributeMap.put("user_id", userId); + + log.debug("Loaded {} attributes for user: {}", attributeMap.size(), userId); + return attributeMap; + } + + /** + * Get agent attributes as a map + * This is a placeholder - actual implementation would depend on how agent attributes are stored + */ + private Map getAgentAttributesMap(String agentId) { + Map attributeMap = new HashMap<>(); + + if (agentId != null) { + attributeMap.put("agent_id", agentId); + // Add more agent-specific attributes as needed + // For example: agent type, capabilities, permissions, etc. + } + + return attributeMap; + } + + /** + * Create a new memory access policy + */ + public MemoryAccessPolicy createPolicy(String policyName, String description, String targetClassification, + String targetMarkings, Map requiredUserAttributes, + String accessType) { + log.info("Creating new memory access policy: {}", policyName); + + MemoryAccessPolicy policy = MemoryAccessPolicy.builder() + .policyName(policyName) + .policyDescription(description) + .targetClassification(targetClassification) + .targetMarkings(targetMarkings) + .accessType(accessType) + .isActive(true) + .build(); + + policy.setRequiredUserAttributesFromMap(requiredUserAttributes); + + return policyRepository.save(policy); + } + + /** + * Check if user has specific attribute value + */ + public boolean userHasAttributeValue(String userId, String attributeName, String attributeValue) { + return userAttributeRepository.userHasAttributeValue(userId, attributeName, attributeValue); + } + + /** + * Get all users with a specific attribute + */ + public List findUsersWithAttribute(String attributeName, String attributeValue) { + return userAttributeRepository.findUserIdsWithAttribute(attributeName, attributeValue); + } + + /** + * Validate memory access request + */ + public AccessValidationResult validateAccess(String agentId, String memoryKey, String userId, String accessType) { + // This method can be used for pre-validation before actual access attempts + // It returns detailed information about why access was granted or denied + + AccessValidationResult result = new AccessValidationResult(); + result.setUserId(userId); + result.setAgentId(agentId); + result.setMemoryKey(memoryKey); + result.setAccessType(accessType); + + // Implementation would include detailed validation logic + // For now, this is a placeholder + result.setAllowed(false); + result.setReason("Validation not implemented"); + + return result; + } + + /** + * Data class for access validation results + */ + public static class AccessValidationResult { + private String userId; + private String agentId; + private String memoryKey; + private String accessType; + private boolean allowed; + private String reason; + private List appliedPolicies = new ArrayList<>(); + + // Getters and setters + public String getUserId() { return userId; } + public void setUserId(String userId) { this.userId = userId; } + + public String getAgentId() { return agentId; } + public void setAgentId(String agentId) { this.agentId = agentId; } + + public String getMemoryKey() { return memoryKey; } + public void setMemoryKey(String memoryKey) { this.memoryKey = memoryKey; } + + public String getAccessType() { return accessType; } + public void setAccessType(String accessType) { this.accessType = accessType; } + + public boolean isAllowed() { return allowed; } + public void setAllowed(boolean allowed) { this.allowed = allowed; } + + public String getReason() { return reason; } + public void setReason(String reason) { this.reason = reason; } + + public List getAppliedPolicies() { return appliedPolicies; } + public void setAppliedPolicies(List appliedPolicies) { this.appliedPolicies = appliedPolicies; } + } +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/services/agents/PersistentAgentMemoryStore.java b/dataplane/src/main/java/io/sentrius/sso/core/services/agents/PersistentAgentMemoryStore.java new file mode 100644 index 00000000..d37256bb --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/services/agents/PersistentAgentMemoryStore.java @@ -0,0 +1,321 @@ +package io.sentrius.sso.core.services.agents; + +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.model.agents.MemoryAccessPolicy; +import io.sentrius.sso.core.repository.AgentMemoryRepository; +import io.sentrius.sso.core.repository.MemoryAccessPolicyRepository; +import io.sentrius.sso.core.repository.UserAttributeRepository; +import io.sentrius.sso.core.utils.JsonUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.time.Instant; +import java.util.*; +import java.util.stream.Collectors; + +@Slf4j +@Service +@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) +public class PersistentAgentMemoryStore { + + private final AgentMemoryRepository agentMemoryRepository; + private final MemoryAccessPolicyRepository policyRepository; + private final UserAttributeRepository userAttributeRepository; + private final MemoryAccessControlService accessControlService; + + private final SystemOptions systemOptions; + + public PersistentAgentMemoryStore( + AgentMemoryRepository agentMemoryRepository, + MemoryAccessPolicyRepository policyRepository, + UserAttributeRepository userAttributeRepository, + MemoryAccessControlService accessControlService, SystemOptions systemOptions) { + this.agentMemoryRepository = agentMemoryRepository; + this.policyRepository = policyRepository; + this.userAttributeRepository = userAttributeRepository; + this.accessControlService = accessControlService; + this.systemOptions = systemOptions; + } + + /** + * Check if memory store is enabled via configuration + */ + private boolean isMemoryStoreEnabled() { + return systemOptions.getEnableMemoryStore(); + } + + /** + * Store memory with markings and access control + */ + @Transactional + public AgentMemory storeMemory(String agentId, String memoryKey, Object memoryValue, + String classification, String[] markings, String creatorUserId) { + + if (!isMemoryStoreEnabled()) { + log.warn("Memory store is disabled via configuration - cannot store memory for agent: {}", agentId); + throw new IllegalStateException("Memory store is disabled"); + } + + log.info("Storing memory for agent: {}, key: {}, classification: {}", agentId, memoryKey, classification); + + try { + // Convert value to JSON string + String valueJson = JsonUtil.MAPPER.writeValueAsString(memoryValue); + + // Check if memory already exists + Optional existing = agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey); + + AgentMemory memory; + if (existing.isPresent()) { + memory = existing.get(); + memory.setMemoryValue(valueJson); + memory.setClassification(classification); + memory.setMarkingsArray(markings); + log.info("Updated existing memory for agent: {}, key: {}", agentId, memoryKey); + } else { + memory = AgentMemory.builder() + .agentId(agentId) + .memoryKey(memoryKey) + .memoryValue(valueJson) + .memoryType("JSON") + .classification(classification != null ? classification : "PRIVATE") + .creatorUserId(creatorUserId) + .accessLevel("AGENT_ONLY") + .build(); + memory.setMarkingsArray(markings); + log.info("Created new memory for agent: {}, key: {}", agentId, memoryKey); + } + + return agentMemoryRepository.save(memory); + } catch (JsonProcessingException e) { + log.error("Error serializing memory value for agent: {}, key: {}", agentId, memoryKey, e); + throw new RuntimeException("Failed to store memory", e); + } + } + + /** + * Retrieve memory with access control validation + */ + public Optional retrieveMemory(String agentId, String memoryKey, String requestingUserId) { + if (!isMemoryStoreEnabled()) { + log.warn("Memory store is disabled via configuration - cannot retrieve memory for agent: {}", agentId); + return Optional.empty(); + } + + log.debug("Retrieving memory for agent: {}, key: {}, user: {}", agentId, memoryKey, requestingUserId); + + Optional memoryOpt = agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey); + + if (memoryOpt.isEmpty()) { + log.debug("Memory not found for agent: {}, key: {}", agentId, memoryKey); + return Optional.empty(); + } + + AgentMemory memory = memoryOpt.get(); + + // Check if memory is expired + if (memory.isExpired()) { + log.debug("Memory expired for agent: {}, key: {}", agentId, memoryKey); + return Optional.empty(); + } + + // Validate access using ABAC + if (!accessControlService.canAccessMemory(memory, requestingUserId, agentId, "READ")) { + log.warn("Access denied to memory for agent: {}, key: {}, user: {}", agentId, memoryKey, requestingUserId); + return Optional.empty(); + } + + return Optional.of(memory); + } + + /** + * Retrieve memory value as specific type + */ + public Optional retrieveMemoryValue(String agentId, String memoryKey, String requestingUserId, Class valueType) { + Optional memoryOpt = retrieveMemory(agentId, memoryKey, requestingUserId); + + if (memoryOpt.isEmpty()) { + return Optional.empty(); + } + + try { + T value = JsonUtil.MAPPER.readValue(memoryOpt.get().getMemoryValue(), valueType); + return Optional.of(value); + } catch (JsonProcessingException e) { + log.error("Error deserializing memory value for agent: {}, key: {}", agentId, memoryKey, e); + return Optional.empty(); + } + } + + /** + * Retrieve memory value as JsonNode + */ + public Optional retrieveMemoryAsJsonNode(String agentId, String memoryKey, String requestingUserId) { + Optional memoryOpt = retrieveMemory(agentId, memoryKey, requestingUserId); + + if (memoryOpt.isEmpty()) { + return Optional.empty(); + } + + try { + JsonNode value = JsonUtil.MAPPER.readTree(memoryOpt.get().getMemoryValue()); + return Optional.of(value); + } catch (JsonProcessingException e) { + log.error("Error parsing memory value as JsonNode for agent: {}, key: {}", agentId, memoryKey, e); + return Optional.empty(); + } + } + + /** + * Find shareable memories for an agent based on markings and access policies + */ + public List findShareableMemories(String agentId, String requestingUserId) { + log.debug("Finding shareable memories for agent: {}, user: {}", agentId, requestingUserId); + + List shareableMemories = agentMemoryRepository.findShareableMemories(agentId, Instant.now()); + + // Filter based on access control policies + return shareableMemories.stream() + .filter(memory -> accessControlService.canAccessMemory(memory, requestingUserId, agentId, "READ")) + .collect(Collectors.toList()); + } + + /** + * Search memories by markings + */ + public List findMemoriesByMarkings(String marking, String requestingUserId) { + log.debug("Searching memories by marking: {}, user: {}", marking, requestingUserId); + + List memories = agentMemoryRepository.findByMarkingsContaining(marking); + + // Filter based on access control policies + return memories.stream() + .filter(memory -> !memory.isExpired()) + .filter(memory -> accessControlService.canAccessMemory(memory, requestingUserId, null, "READ")) + .collect(Collectors.toList()); + } + + /** + * Query memories with filters and pagination + */ + public Page queryMemories(String agentId, String classification, String markings, + String requestingUserId, Pageable requestedPageable) { + log.debug("Querying memories with filters - agent: {}, classification: {}, markings: {}, user: {}", + agentId, classification, markings, requestingUserId); + + Pageable pageable = PageRequest.of(requestedPageable.getPageNumber(), requestedPageable.getPageSize(), Sort.unsorted()); + + Page memories = agentMemoryRepository.findMemoriesWithFiltersNative( + agentId, classification, markings, Instant.now(), pageable); + + // Note: For large datasets, consider implementing access control at the database level + // For now, we filter in memory + var filteredMemories = memories.map(memory -> + accessControlService.canAccessMemory(memory, requestingUserId, agentId, "READ") ? memory : null) + .map(memory -> memory); // Remove nulls would need additional implementation + return filteredMemories.stream().filter(x -> x != null).collect(Collectors.toList()) + .stream() + .collect(Collectors.collectingAndThen(Collectors.toList(), list -> + new org.springframework.data.domain.PageImpl<>(list, pageable, list.size()))); + } + + /** + * Share memory with specific agents + */ + @Transactional + public boolean shareMemoryWithAgents(String agentId, String memoryKey, String[] targetAgents, String requestingUserId) { + log.info("Sharing memory {} from agent {} with agents: {}", memoryKey, agentId, Arrays.toString(targetAgents)); + + Optional memoryOpt = agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey); + + if (memoryOpt.isEmpty()) { + log.warn("Memory not found for sharing: agent={}, key={}", agentId, memoryKey); + return false; + } + + AgentMemory memory = memoryOpt.get(); + + // Check if user can modify this memory + if (!accessControlService.canAccessMemory(memory, requestingUserId, agentId, "WRITE")) { + log.warn("User {} cannot modify memory: agent={}, key={}", requestingUserId, agentId, memoryKey); + return false; + } + + // Update shared agents list + Set currentShared = new HashSet<>(Arrays.asList(memory.getSharedAgentsArray())); + currentShared.addAll(Arrays.asList(targetAgents)); + memory.setSharedAgentsArray(currentShared.toArray(new String[0])); + + agentMemoryRepository.save(memory); + log.info("Successfully shared memory {} with {} agents", memoryKey, targetAgents.length); + return true; + } + + /** + * Delete memory + */ + @Transactional + public boolean deleteMemory(String agentId, String memoryKey, String requestingUserId) { + log.info("Deleting memory: agent={}, key={}, user={}", agentId, memoryKey, requestingUserId); + + Optional memoryOpt = agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey); + + if (memoryOpt.isEmpty()) { + log.warn("Memory not found for deletion: agent={}, key={}", agentId, memoryKey); + return false; + } + + AgentMemory memory = memoryOpt.get(); + + // Check if user can delete this memory + if (!accessControlService.canAccessMemory(memory, requestingUserId, agentId, "DELETE")) { + log.warn("User {} cannot delete memory: agent={}, key={}", requestingUserId, agentId, memoryKey); + return false; + } + + agentMemoryRepository.delete(memory); + log.info("Successfully deleted memory: agent={}, key={}", agentId, memoryKey); + return true; + } + + /** + * Clean up expired memories + */ + @Transactional + public void cleanupExpiredMemories() { + log.info("Cleaning up expired memories"); + + List expiredMemories = agentMemoryRepository.findExpiredMemories(Instant.now()); + + if (!expiredMemories.isEmpty()) { + agentMemoryRepository.deleteAll(expiredMemories); + log.info("Cleaned up {} expired memories", expiredMemories.size()); + } + } + + /** + * Get memory statistics for an agent + */ + public Map getMemoryStatistics(String agentId) { + Map stats = new HashMap<>(); + + stats.put("total_memories", agentMemoryRepository.countByAgentId(agentId)); + stats.put("public_memories", agentMemoryRepository.countByClassification("PUBLIC")); + stats.put("private_memories", agentMemoryRepository.countByClassification("PRIVATE")); + stats.put("shared_memories", agentMemoryRepository.countByClassification("SHARED")); + + return stats; + } +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStore.java b/dataplane/src/main/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStore.java new file mode 100644 index 00000000..c79253a0 --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStore.java @@ -0,0 +1,293 @@ +package io.sentrius.sso.core.services.agents; + +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.repository.AgentMemoryRepository; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; +import org.springframework.data.domain.PageRequest; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.util.*; +import java.util.stream.Collectors; + +/** + * Vector-enhanced agent memory store that provides semantic search capabilities + * while maintaining the existing ABAC security model and markings-based access control. + */ +@Slf4j +@Service +@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) +public class VectorAgentMemoryStore { + + private final PersistentAgentMemoryStore persistentMemoryStore; + private final AgentMemoryRepository agentMemoryRepository; + private final EmbeddingService embeddingService; + private final MemoryAccessControlService accessControlService; + + public VectorAgentMemoryStore( + PersistentAgentMemoryStore persistentMemoryStore, + AgentMemoryRepository agentMemoryRepository, + EmbeddingService embeddingService, + MemoryAccessControlService accessControlService) { + this.persistentMemoryStore = persistentMemoryStore; + this.agentMemoryRepository = agentMemoryRepository; + this.embeddingService = embeddingService; + this.accessControlService = accessControlService; + } + + /** + * Store memory with automatic embedding generation + */ + @Transactional + public AgentMemory storeMemoryWithEmbedding(String agentId, String memoryKey, Object memoryValue, + String classification, String[] markings, String creatorUserId) { + log.info("Storing memory with embedding for agent: {}, key: {}", agentId, memoryKey); + + // Store the memory using the existing service + AgentMemory memory = persistentMemoryStore.storeMemory(agentId, memoryKey, memoryValue, + classification, markings, creatorUserId); + + // Generate and store embedding if embedding service is available + if (embeddingService.isAvailable()) { + try { + generateAndStoreEmbedding(memory); + log.info("Generated embedding for memory: agent={}, key={}", agentId, memoryKey); + } catch (Exception e) { + log.warn("Failed to generate embedding for memory: agent={}, key={}, error={}", + agentId, memoryKey, e.getMessage()); + // Continue without embedding - memory is still stored with text-based search + } + } else { + log.debug("No embedding service available - storing memory without embedding"); + } + + return memory; + } + + /** + * Find semantically similar memories using vector similarity + */ + public List findSimilarMemories(String queryText, String requestingUserId, + int limit, double threshold) { + log.debug("Finding similar memories for query: {}, user: {}", queryText, requestingUserId); + + if (embeddingService == null || !embeddingService.isAvailable()) { + log.debug("No embedding service available - falling back to text search"); + return fallbackToTextSearch(queryText, requestingUserId, limit); + } + + try { + // Generate embedding for the query + float[] queryEmbedding = embeddingService.embed(queryText); + if (queryEmbedding == null) { + return fallbackToTextSearch(queryText, requestingUserId, limit); + } + + String embeddingString = Arrays.toString(queryEmbedding); + + // Find similar memories using vector similarity + List similarMemories = agentMemoryRepository.findSimilarMemories(embeddingString, limit * 2); + + // Filter based on access control and threshold + return similarMemories.stream() + .filter(memory -> !memory.isExpired()) + .filter(memory -> memory.calculateCosineSimilarity(queryEmbedding) >= threshold) + .filter(memory -> accessControlService.canAccessMemory(memory, requestingUserId, null, "READ")) + .limit(limit) + .collect(Collectors.toList()); + + } catch (Exception e) { + log.error("Error in semantic search, falling back to text search", e); + return fallbackToTextSearch(queryText, requestingUserId, limit); + } + } + + /** + * Find similar memories for a specific agent with access control + */ + public List findSimilarMemoriesForAgent(String queryText, String agentId, + String requestingUserId, int limit, double threshold) { + log.debug("Finding similar memories for agent: {}, query: {}, user: {}", agentId, queryText, requestingUserId); + + if (embeddingService == null || !embeddingService.isAvailable()) { + return persistentMemoryStore.findShareableMemories(agentId, requestingUserId) + .stream().limit(limit).collect(Collectors.toList()); + } + + try { + float[] queryEmbedding = embeddingService.embed(queryText); + if (queryEmbedding == null) { + return persistentMemoryStore.findShareableMemories(agentId, requestingUserId) + .stream().limit(limit).collect(Collectors.toList()); + } + + String embeddingString = Arrays.toString(queryEmbedding); + + List similarMemories = agentMemoryRepository.findSimilarMemoriesForAgent( + embeddingString, agentId, threshold, limit * 2); + + return similarMemories.stream() + .filter(memory -> !memory.isExpired()) + .filter(memory -> accessControlService.canAccessMemory(memory, requestingUserId, agentId, "READ")) + .limit(limit) + .collect(Collectors.toList()); + + } catch (Exception e) { + log.error("Error in agent-specific semantic search", e); + return persistentMemoryStore.findShareableMemories(agentId, requestingUserId) + .stream().limit(limit).collect(Collectors.toList()); + } + } + + /** + * Hybrid search combining text and vector similarity with markings filter + */ + public List hybridSearch(String searchTerm, String markingsFilter, + String requestingUserId, int limit, double threshold) { + log.debug("Hybrid search - term: {}, markings: {}, user: {}", searchTerm, markingsFilter, requestingUserId); + + if (embeddingService == null || !embeddingService.isAvailable()) { + return persistentMemoryStore.findMemoriesByMarkings(markingsFilter, requestingUserId) + .stream().limit(limit).collect(Collectors.toList()); + } + + try { + float[] queryEmbedding = embeddingService.embed(searchTerm); + if (queryEmbedding == null) { + return persistentMemoryStore.findMemoriesByMarkings(markingsFilter, requestingUserId) + .stream().limit(limit).collect(Collectors.toList()); + } + + String embeddingString = Arrays.toString(queryEmbedding); + + List results; + if (markingsFilter != null && !markingsFilter.trim().isEmpty()) { + results = agentMemoryRepository.findSimilarMemoriesByMarkings(embeddingString, markingsFilter, limit * 2); + } else { + results = agentMemoryRepository.hybridSearch(searchTerm, embeddingString, threshold, limit * 2); + } + + return results.stream() + .filter(memory -> !memory.isExpired()) + .filter(memory -> accessControlService.canAccessMemory(memory, requestingUserId, null, "READ")) + .limit(limit) + .collect(Collectors.toList()); + + } catch (Exception e) { + log.error("Error in hybrid search", e); + return persistentMemoryStore.findMemoriesByMarkings(markingsFilter, requestingUserId) + .stream().limit(limit).collect(Collectors.toList()); + } + } + + /** + * Generate embeddings for memories that don't have them yet + */ + @Transactional + public void generateMissingEmbeddings(int batchSize) { + if (embeddingService == null || !embeddingService.isAvailable()) { + log.debug("No embedding service available - skipping embedding generation"); + return; + } + + log.info("Generating missing embeddings with batch size: {}", batchSize); + + List memoriesWithoutEmbeddings = agentMemoryRepository + .findMemoriesWithoutEmbeddings(PageRequest.of(0, batchSize)); + + int processed = 0; + for (AgentMemory memory : memoriesWithoutEmbeddings) { + try { + generateAndStoreEmbedding(memory); + processed++; + + if (processed % 10 == 0) { + log.info("Generated embeddings for {} memories", processed); + } + } catch (Exception e) { + log.warn("Failed to generate embedding for memory ID: {}, error: {}", + memory.getId(), e.getMessage()); + } + } + + log.info("Completed embedding generation: {} out of {} memories processed", + processed, memoriesWithoutEmbeddings.size()); + } + + /** + * Get statistics about vector store usage + */ + public Map getVectorStoreStatistics() { + Map stats = new HashMap<>(); + + long totalMemories = agentMemoryRepository.count(); + long memoriesWithEmbeddings = agentMemoryRepository.countMemoriesWithEmbeddings(); + + stats.put("total_memories", totalMemories); + stats.put("memories_with_embeddings", memoriesWithEmbeddings); + stats.put("embedding_coverage_percentage", + totalMemories > 0 ? (memoriesWithEmbeddings * 100.0 / totalMemories) : 0.0); + stats.put("embedding_service_available", embeddingService != null && embeddingService.isAvailable()); + stats.put("vector_store_enabled", true); + + return stats; + } + + // Private helper methods + + private void generateAndStoreEmbedding(AgentMemory memory) { + // Create text for embedding from memory content and metadata + String textForEmbedding = buildTextForEmbedding(memory); + + // Generate embedding + float[] embedding = embeddingService.embed(textForEmbedding); + if (embedding == null) { + throw new RuntimeException("Failed to generate embedding"); + } + + // Store embedding in the memory object + memory.setEmbedding(embedding); + agentMemoryRepository.save(memory); + } + + private String buildTextForEmbedding(AgentMemory memory) { + StringBuilder text = new StringBuilder(); + + // Include memory key and value + if (memory.getMemoryKey() != null) { + text.append(memory.getMemoryKey()).append(" "); + } + if (memory.getMemoryValue() != null) { + text.append(memory.getMemoryValue()).append(" "); + } + + // Include markings for context + if (memory.getMarkings() != null) { + text.append("markings: ").append(memory.getMarkings()).append(" "); + } + + // Include classification for context + if (memory.getClassification() != null) { + text.append("classification: ").append(memory.getClassification()); + } + + return text.toString().trim(); + } + + private List fallbackToTextSearch(String queryText, String requestingUserId, int limit) { + // Use existing text-based search as fallback + return agentMemoryRepository.searchByMemoryValue(queryText) + .stream() + .filter(memory -> !memory.isExpired()) + .filter(memory -> accessControlService.canAccessMemory(memory, requestingUserId, null, "READ")) + .limit(limit) + .collect(Collectors.toList()); + } + + public AgentMemory storeMemoryWithProvidedEmbedding(String agentId, String memoryKey, String memoryValue, String classification, String[] markings, float[] embedding, String userId) { + AgentMemory memory = persistentMemoryStore.storeMemory(agentId, memoryKey, memoryValue, classification, markings, userId); + memory.setEmbedding(embedding); + return agentMemoryRepository.save(memory); + } +} \ No newline at end of file diff --git a/dataplane/src/main/java/io/sentrius/sso/core/services/users/UserAttributeService.java b/dataplane/src/main/java/io/sentrius/sso/core/services/users/UserAttributeService.java new file mode 100644 index 00000000..e14a47d7 --- /dev/null +++ b/dataplane/src/main/java/io/sentrius/sso/core/services/users/UserAttributeService.java @@ -0,0 +1,306 @@ +package io.sentrius.sso.core.services.users; + +import io.sentrius.sso.core.model.users.UserAttribute; +import io.sentrius.sso.core.repository.UserAttributeRepository; +import io.sentrius.sso.core.services.security.KeycloakService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.util.*; +import java.util.stream.Collectors; +import java.util.Comparator; + +@Slf4j +@Service +@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) +public class UserAttributeService { + + private final UserAttributeRepository userAttributeRepository; + private final KeycloakService keycloakService; + + public UserAttributeService(UserAttributeRepository userAttributeRepository, KeycloakService keycloakService) { + this.userAttributeRepository = userAttributeRepository; + this.keycloakService = keycloakService; + } + + /** + * Get all active attributes for a user + */ + public List getUserAttributes(String userId) { + log.debug("Getting attributes for user: {}", userId); + return userAttributeRepository.findByUserIdAndIsActiveTrue(userId); + } + + /** + * Get user attributes as a map + */ + public Map getUserAttributesAsMap(String userId) { + List attributes = getUserAttributes(userId); + return attributes.stream() + .collect(Collectors.toMap( + UserAttribute::getAttributeName, + UserAttribute::getAttributeValue, + (existing, replacement) -> replacement)); + } + + /** + * Get a specific user attribute + */ + public Optional getUserAttribute(String userId, String attributeName) { + return userAttributeRepository.findByUserIdAndAttributeNameAndIsActiveTrue(userId, attributeName); + } + + /** + * Get a user attribute value + */ + public Optional getUserAttributeValue(String userId, String attributeName) { + return getUserAttribute(userId, attributeName) + .map(UserAttribute::getAttributeValue); + } + + /** + * Set a user attribute + */ + @Transactional + public UserAttribute setUserAttribute(String userId, String attributeName, String attributeValue, + String attributeType, String source) { + log.info("Setting attribute for user: {}, name: {}, value: {}", userId, attributeName, attributeValue); + + Optional existingOpt = userAttributeRepository + .findByUserIdAndAttributeNameAndIsActiveTrue(userId, attributeName); + + UserAttribute attribute; + if (existingOpt.isPresent()) { + attribute = existingOpt.get(); + attribute.setAttributeValue(attributeValue); + attribute.setAttributeType(attributeType != null ? attributeType : "STRING"); + attribute.setSource(source != null ? source : "SENTRIUS"); + log.debug("Updated existing attribute for user: {}, name: {}", userId, attributeName); + } else { + attribute = UserAttribute.builder() + .userId(userId) + .attributeName(attributeName) + .attributeValue(attributeValue) + .attributeType(attributeType != null ? attributeType : "STRING") + .source(source != null ? source : "SENTRIUS") + .isActive(true) + .syncedFromKeycloak(false) + .build(); + log.debug("Created new attribute for user: {}, name: {}", userId, attributeName); + } + + // Validate the attribute value for its type + if (!attribute.isValidForType()) { + throw new IllegalArgumentException("Invalid value for attribute type: " + attributeType); + } + + return userAttributeRepository.save(attribute); + } + + /** + * Set multiple user attributes at once + */ + @Transactional + public List setUserAttributes(String userId, Map attributes, String source) { + log.info("Setting {} attributes for user: {}", attributes.size(), userId); + + List savedAttributes = new ArrayList<>(); + for (Map.Entry entry : attributes.entrySet()) { + UserAttribute attr = setUserAttribute(userId, entry.getKey(), entry.getValue(), "STRING", source); + savedAttributes.add(attr); + } + + return savedAttributes; + } + + /** + * Remove a user attribute + */ + @Transactional + public boolean removeUserAttribute(String userId, String attributeName) { + log.info("Removing attribute for user: {}, name: {}", userId, attributeName); + + Optional attributeOpt = userAttributeRepository + .findByUserIdAndAttributeNameAndIsActiveTrue(userId, attributeName); + + if (attributeOpt.isPresent()) { + UserAttribute attribute = attributeOpt.get(); + attribute.setIsActive(false); + userAttributeRepository.save(attribute); + log.info("Deactivated attribute for user: {}, name: {}", userId, attributeName); + return true; + } + + log.warn("Attribute not found for removal: user={}, name={}", userId, attributeName); + return false; + } + + /** + * Sync user attributes from Keycloak + */ + @Transactional + public List syncUserAttributesFromKeycloak(String userId) { + log.info("Syncing user attributes from Keycloak for user: {}", userId); + + try { + Map> keycloakAttributes = keycloakService.getUserAttributes(userId); + List syncedAttributes = new ArrayList<>(); + + for (Map.Entry> entry : keycloakAttributes.entrySet()) { + String attributeName = entry.getKey(); + List values = entry.getValue(); + + if (values != null && !values.isEmpty()) { + // For multiple values, we'll store them as a comma-separated list + String attributeValue = values.size() == 1 ? values.get(0) : String.join(",", values); + String attributeType = values.size() == 1 ? "STRING" : "LIST"; + + UserAttribute attribute = setUserAttribute(userId, attributeName, attributeValue, + attributeType, "KEYCLOAK"); + attribute.setSyncedFromKeycloak(true); + attribute = userAttributeRepository.save(attribute); + syncedAttributes.add(attribute); + } + } + + log.info("Synced {} attributes from Keycloak for user: {}", syncedAttributes.size(), userId); + return syncedAttributes; + + } catch (Exception e) { + log.error("Error syncing attributes from Keycloak for user: {}", userId, e); + return Collections.emptyList(); + } + } + + /** + * Check if user has specific attribute value + */ + public boolean userHasAttributeValue(String userId, String attributeName, String attributeValue) { + return userAttributeRepository.userHasAttributeValue(userId, attributeName, attributeValue); + } + + /** + * Find users with a specific attribute + */ + public List findUsersWithAttribute(String attributeName, String attributeValue) { + return userAttributeRepository.findUserIdsWithAttribute(attributeName, attributeValue); + } + + /** + * Get all unique attribute names + */ + public List getAllAttributeNames() { + return userAttributeRepository.findAll() + .stream() + .filter(UserAttribute::getIsActive) + .map(UserAttribute::getAttributeName) + .distinct() + .sorted() + .collect(Collectors.toList()); + } + + /** + * Get attribute statistics + */ + public Map getAttributeStatistics() { + Map stats = new HashMap<>(); + + List allAttributes = findByIsActiveTrueOrderByAttributeName(); + + // Count by attribute name + Map countByName = allAttributes.stream() + .collect(Collectors.groupingBy(UserAttribute::getAttributeName, Collectors.counting())); + + // Count by source + Map countBySource = allAttributes.stream() + .collect(Collectors.groupingBy(UserAttribute::getSource, Collectors.counting())); + + // Count by type + Map countByType = allAttributes.stream() + .collect(Collectors.groupingBy(UserAttribute::getAttributeType, Collectors.counting())); + + stats.put("total_attributes", (long) allAttributes.size()); + stats.put("unique_attribute_names", (long) countByName.size()); + stats.put("keycloak_synced", countBySource.getOrDefault("KEYCLOAK", 0L)); + stats.put("sentrius_managed", countBySource.getOrDefault("SENTRIUS", 0L)); + + return stats; + } + + /** + * Validate user attributes for ABAC policies + */ + public boolean validateUserForPolicy(String userId, Map requiredAttributes) { + if (requiredAttributes == null || requiredAttributes.isEmpty()) { + return true; + } + + Map userAttributes = getUserAttributesAsMap(userId); + + for (Map.Entry required : requiredAttributes.entrySet()) { + String requiredKey = required.getKey(); + Object requiredValue = required.getValue(); + + // Special handling for user_id + if ("user_id".equals(requiredKey)) { + if (!userId.equals(requiredValue)) { + return false; + } + continue; + } + + if (!userAttributes.containsKey(requiredKey)) { + log.debug("User {} missing required attribute: {}", userId, requiredKey); + return false; + } + + String userValue = userAttributes.get(requiredKey); + if (!requiredValue.toString().equals(userValue)) { + log.debug("User {} attribute {} value mismatch: required={}, actual={}", + userId, requiredKey, requiredValue, userValue); + return false; + } + } + + return true; + } + + /** + * Bulk operation to sync all users' attributes from Keycloak + */ + @Transactional + public void syncAllUsersFromKeycloak() { + log.info("Starting bulk sync of all user attributes from Keycloak"); + + try { + // This would typically get all users from the User repository + // For now, this is a placeholder implementation + log.warn("Bulk sync not implemented - would need access to User repository"); + + } catch (Exception e) { + log.error("Error during bulk sync from Keycloak", e); + } + } + + /** + * Clean up inactive attributes older than specified days + */ + @Transactional + public int cleanupInactiveAttributes(int olderThanDays) { + log.info("Cleaning up inactive attributes older than {} days", olderThanDays); + + // This would require additional query in repository + // For now, return 0 as placeholder + return 0; + } + + private List findByIsActiveTrueOrderByAttributeName() { + return userAttributeRepository.findAll() + .stream() + .filter(UserAttribute::getIsActive) + .sorted(Comparator.comparing(UserAttribute::getAttributeName)) + .collect(Collectors.toList()); + } +} \ No newline at end of file diff --git a/dataplane/src/test/java/io/sentrius/sso/core/TestApplication.java b/dataplane/src/test/java/io/sentrius/sso/core/TestApplication.java new file mode 100644 index 00000000..4171a2c6 --- /dev/null +++ b/dataplane/src/test/java/io/sentrius/sso/core/TestApplication.java @@ -0,0 +1,12 @@ +package io.sentrius.sso.core; + + +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.autoconfigure.domain.EntityScan; +import org.springframework.data.jpa.repository.config.EnableJpaRepositories; + +@EntityScan("io.sentrius.sso.core.model") +@EnableJpaRepositories("io.sentrius.sso.core.repository") +@SpringBootApplication +class TestApplication { +} \ No newline at end of file diff --git a/dataplane/src/test/java/io/sentrius/sso/core/services/agents/AgentMemoryUnitTest.java b/dataplane/src/test/java/io/sentrius/sso/core/services/agents/AgentMemoryUnitTest.java new file mode 100644 index 00000000..1b155f66 --- /dev/null +++ b/dataplane/src/test/java/io/sentrius/sso/core/services/agents/AgentMemoryUnitTest.java @@ -0,0 +1,256 @@ +package io.sentrius.sso.core.services.agents; + +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.model.agents.MemoryAccessPolicy; +import io.sentrius.sso.core.model.users.UserAttribute; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayName; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +@DisplayName("Agent Memory and Access Control Tests") +class AgentMemoryUnitTest { + + @Test + @DisplayName("Should create agent memory with markings") + void testCreateAgentMemoryWithMarkings() { + // Arrange & Act + AgentMemory memory = AgentMemory.builder() + .agentId("test-agent") + .memoryKey("config-data") + .memoryValue("{\"key\": \"value\"}") + .classification("SHARED") + .markings("DEVELOPMENT,CONFIG") + .creatorUserId("user-123") + .accessLevel("TEAM_MEMBERS") + .build(); + + // Assert + assertNotNull(memory); + assertEquals("test-agent", memory.getAgentId()); + assertEquals("SHARED", memory.getClassification()); + assertTrue(memory.hasMarking("DEVELOPMENT")); + assertTrue(memory.hasMarking("CONFIG")); + assertFalse(memory.hasMarking("PRODUCTION")); + } + + @Test + @DisplayName("Should handle memory sharing between agents") + void testMemorySharing() { + // Arrange + AgentMemory memory = AgentMemory.builder() + .agentId("agent-1") + .memoryKey("shared-knowledge") + .memoryValue("\"shared information\"") + .classification("SHARED") + .sharedWithAgents("agent-2,agent-3") + .build(); + + // Act & Assert + assertTrue(memory.canBeSharedWith("agent-2")); + assertTrue(memory.canBeSharedWith("agent-3")); + assertFalse(memory.canBeSharedWith("agent-4")); + } + + @Test + @DisplayName("Should validate memory expiration") + void testMemoryExpiration() { + // Arrange + AgentMemory expiredMemory = AgentMemory.builder() + .memoryKey("expired-data") + .expiresAt(java.time.Instant.now().minusSeconds(3600)) // 1 hour ago + .build(); + + AgentMemory validMemory = AgentMemory.builder() + .memoryKey("valid-data") + .expiresAt(java.time.Instant.now().plusSeconds(3600)) // 1 hour from now + .build(); + + AgentMemory noExpirationMemory = AgentMemory.builder() + .memoryKey("permanent-data") + .expiresAt(null) // No expiration + .build(); + + // Act & Assert + assertTrue(expiredMemory.isExpired()); + assertFalse(validMemory.isExpired()); + assertFalse(noExpirationMemory.isExpired()); + } + + @Test + @DisplayName("Should validate user attributes") + void testUserAttributeValidation() { + // Arrange & Act + UserAttribute stringAttr = UserAttribute.builder() + .attributeName("team") + .attributeValue("development") + .attributeType("STRING") + .build(); + + UserAttribute intAttr = UserAttribute.builder() + .attributeName("priority") + .attributeValue("5") + .attributeType("INTEGER") + .build(); + + UserAttribute boolAttr = UserAttribute.builder() + .attributeName("active") + .attributeValue("true") + .attributeType("BOOLEAN") + .build(); + + UserAttribute invalidIntAttr = UserAttribute.builder() + .attributeName("invalid") + .attributeValue("not-a-number") + .attributeType("INTEGER") + .build(); + + // Assert + assertTrue(stringAttr.isValidForType()); + assertTrue(intAttr.isValidForType()); + assertTrue(boolAttr.isValidForType()); + assertFalse(invalidIntAttr.isValidForType()); + + assertEquals("development", stringAttr.getStringValue()); + assertEquals(5, intAttr.getIntegerValue()); + assertTrue(boolAttr.getBooleanValue()); + } + + @Test + @DisplayName("Should evaluate memory access policies") + void testMemoryAccessPolicyEvaluation() { + // Arrange + MemoryAccessPolicy policy = MemoryAccessPolicy.builder() + .policyName("TEAM_ACCESS") + .targetClassification("SHARED") + .targetMarkings("DEVELOPMENT") + .accessType("read") + .isActive(true) + .build(); + + Map requiredAttributes = new HashMap<>(); + requiredAttributes.put("team", "development"); + requiredAttributes.put("user_type", "DEVELOPER"); + policy.setRequiredUserAttributesFromMap(requiredAttributes); + + // Test data + Map validUserAttributes = new HashMap<>(); + validUserAttributes.put("team", "development"); + validUserAttributes.put("user_type", "DEVELOPER"); + + Map invalidUserAttributes = new HashMap<>(); + invalidUserAttributes.put("team", "operations"); + invalidUserAttributes.put("user_type", "DEVELOPER"); + + // Act & Assert + assertTrue(policy.appliesToClassification("SHARED")); + assertTrue(policy.appliesToMarkings("DEVELOPMENT,CONFIG")); + assertTrue(policy.allowsAccessType("read")); + + assertTrue(policy.evaluateUserAttributes(validUserAttributes)); + assertFalse(policy.evaluateUserAttributes(invalidUserAttributes)); + } + + @Test + @DisplayName("Should handle agent memory metadata") + void testAgentMemoryMetadata() { + // Arrange + AgentMemory memory = AgentMemory.builder() + .memoryKey("config-with-metadata") + .memoryValue("\"configuration\"") + .build(); + + Map metadata = new HashMap<>(); + metadata.put("category", "configuration"); + metadata.put("priority", 5); + metadata.put("tags", Arrays.asList("config", "system")); + + // Act + memory.setMetadataFromMap(metadata); + + // Assert + Map retrievedMetadata = memory.getMetadataAsMap(); + assertEquals("configuration", retrievedMetadata.get("category")); + assertEquals(5, retrievedMetadata.get("priority")); + assertNotNull(retrievedMetadata.get("tags")); + } + + @Test + @DisplayName("Should handle memory classification levels") + void testMemoryClassificationLevels() { + // Test all classification levels + String[] classifications = {"PUBLIC", "PRIVATE", "SHARED", "CONFIDENTIAL", "SECRET"}; + + for (String classification : classifications) { + AgentMemory memory = AgentMemory.builder() + .memoryKey("test-" + classification.toLowerCase()) + .classification(classification) + .build(); + + assertEquals(classification, memory.getClassification()); + } + } + + @Test + @DisplayName("Should handle complex markings scenarios") + void testComplexMarkingsScenarios() { + // Arrange + AgentMemory memory = AgentMemory.builder() + .memoryKey("complex-markings") + .markings("DEVELOPMENT,TESTING,CONFIG,SENSITIVE") + .build(); + + // Test multiple markings + String[] expectedMarkings = {"DEVELOPMENT", "TESTING", "CONFIG", "SENSITIVE"}; + String[] actualMarkings = memory.getMarkingsArray(); + + assertEquals(expectedMarkings.length, actualMarkings.length); + + for (String expectedMarking : expectedMarkings) { + assertTrue(memory.hasMarking(expectedMarking)); + } + + // Test case insensitive marking check + assertTrue(memory.hasMarking("development")); + assertTrue(memory.hasMarking("DEVELOPMENT")); + assertFalse(memory.hasMarking("PRODUCTION")); + } + + @Test + @DisplayName("Should demonstrate cross-agent memory sharing workflow") + void testCrossAgentMemorySharingWorkflow() { + // Simulate a cross-agent memory sharing scenario + + // Agent 1 creates memory + AgentMemory sharedMemory = AgentMemory.builder() + .agentId("intelligent-agent-1") + .memoryKey("learned-patterns") + .memoryValue("{\"patterns\": [\"pattern1\", \"pattern2\"]}") + .classification("SHARED") + .markings("MACHINE_LEARNING,PATTERNS") + .creatorUserId("data-scientist-1") + .accessLevel("ALL_USERS") + .build(); + + // Agent 1 shares with specific agents + String[] targetAgents = {"intelligent-agent-2", "intelligent-agent-3"}; + sharedMemory.setSharedAgentsArray(targetAgents); + + // Verify sharing setup + assertTrue(sharedMemory.canBeSharedWith("intelligent-agent-2")); + assertTrue(sharedMemory.canBeSharedWith("intelligent-agent-3")); + // ALL_USERS access level allows any agent to access + assertTrue(sharedMemory.canBeSharedWith("intelligent-agent-4")); + + // Verify markings for filtering + assertTrue(sharedMemory.hasMarking("MACHINE_LEARNING")); + assertTrue(sharedMemory.hasMarking("PATTERNS")); + + // Simulate access control check + assertEquals("SHARED", sharedMemory.getClassification()); + assertEquals("ALL_USERS", sharedMemory.getAccessLevel()); + } +} \ No newline at end of file diff --git a/dataplane/src/test/java/io/sentrius/sso/core/services/agents/PersistentAgentMemoryStoreTest.java b/dataplane/src/test/java/io/sentrius/sso/core/services/agents/PersistentAgentMemoryStoreTest.java new file mode 100644 index 00000000..cf937ca4 --- /dev/null +++ b/dataplane/src/test/java/io/sentrius/sso/core/services/agents/PersistentAgentMemoryStoreTest.java @@ -0,0 +1,233 @@ +package io.sentrius.sso.core.services.agents; + +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.model.agents.MemoryAccessPolicy; +import io.sentrius.sso.core.model.users.UserAttribute; +import io.sentrius.sso.core.repository.AgentMemoryRepository; +import io.sentrius.sso.core.repository.MemoryAccessPolicyRepository; +import io.sentrius.sso.core.repository.UserAttributeRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class PersistentAgentMemoryStoreTest { + + @Mock + private AgentMemoryRepository agentMemoryRepository; + + @Mock + private MemoryAccessPolicyRepository policyRepository; + + @Mock + private UserAttributeRepository userAttributeRepository; + + @Mock + private MemoryAccessControlService accessControlService; + + private PersistentAgentMemoryStore memoryStore; + + @BeforeEach + void setUp() { + SystemOptions systemOptions = new SystemOptions(); + memoryStore = new PersistentAgentMemoryStore( + agentMemoryRepository, + policyRepository, + userAttributeRepository, + accessControlService, systemOptions + ); + } + + @Test + void testStoreMemory_NewMemory_ShouldCreateAndSave() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String memoryValue = "test-value"; + String classification = "PRIVATE"; + String[] markings = {"TEST", "DEMO"}; + String creatorUserId = "user-123"; + + when(agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey)) + .thenReturn(Optional.empty()); + when(agentMemoryRepository.save(any(AgentMemory.class))) + .thenAnswer(invocation -> invocation.getArgument(0)); + + // Act + AgentMemory result = memoryStore.storeMemory(agentId, memoryKey, memoryValue, + classification, markings, creatorUserId); + + // Assert + assertNotNull(result); + assertEquals(agentId, result.getAgentId()); + assertEquals(memoryKey, result.getMemoryKey()); + assertEquals(classification, result.getClassification()); + assertEquals(creatorUserId, result.getCreatorUserId()); + verify(agentMemoryRepository).save(any(AgentMemory.class)); + } + + @Test + void testRetrieveMemory_ExistingMemory_ShouldReturnMemory() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String requestingUserId = "user-123"; + + AgentMemory memory = AgentMemory.builder() + .agentId(agentId) + .memoryKey(memoryKey) + .memoryValue("\"test-value\"") + .classification("PRIVATE") + .creatorUserId(requestingUserId) + .build(); + + when(agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey)) + .thenReturn(Optional.of(memory)); + when(accessControlService.canAccessMemory(memory, requestingUserId, agentId, "READ")) + .thenReturn(true); + + // Act + Optional result = memoryStore.retrieveMemory(agentId, memoryKey, requestingUserId); + + // Assert + assertTrue(result.isPresent()); + assertEquals(memory, result.get()); + verify(accessControlService).canAccessMemory(memory, requestingUserId, agentId, "READ"); + } + + @Test + void testRetrieveMemory_AccessDenied_ShouldReturnEmpty() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String requestingUserId = "user-123"; + + AgentMemory memory = AgentMemory.builder() + .agentId(agentId) + .memoryKey(memoryKey) + .memoryValue("\"test-value\"") + .classification("CONFIDENTIAL") + .creatorUserId("other-user") + .build(); + + when(agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey)) + .thenReturn(Optional.of(memory)); + when(accessControlService.canAccessMemory(memory, requestingUserId, agentId, "READ")) + .thenReturn(false); + + // Act + Optional result = memoryStore.retrieveMemory(agentId, memoryKey, requestingUserId); + + // Assert + assertTrue(result.isEmpty()); + verify(accessControlService).canAccessMemory(memory, requestingUserId, agentId, "READ"); + } + + @Test + void testFindShareableMemories_ShouldFilterByAccessControl() { + // Arrange + String agentId = "test-agent"; + String requestingUserId = "user-123"; + + List shareableMemories = Arrays.asList( + AgentMemory.builder().agentId(agentId).memoryKey("key1").classification("PUBLIC").build(), + AgentMemory.builder().agentId("other-agent").memoryKey("key2").classification("SHARED").build() + ); + + when(agentMemoryRepository.findShareableMemories(eq(agentId), any())) + .thenReturn(shareableMemories); + when(accessControlService.canAccessMemory(any(), eq(requestingUserId), eq(agentId), eq("READ"))) + .thenReturn(true, false); // First memory allowed, second denied + + // Act + List result = memoryStore.findShareableMemories(agentId, requestingUserId); + + // Assert + assertEquals(1, result.size()); + assertEquals("key1", result.get(0).getMemoryKey()); + } + + @Test + void testShareMemoryWithAgents_SuccessfulSharing_ShouldUpdateMemory() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String[] targetAgents = {"agent-1", "agent-2"}; + String requestingUserId = "user-123"; + + AgentMemory memory = AgentMemory.builder() + .agentId(agentId) + .memoryKey(memoryKey) + .sharedWithAgents("") + .build(); + + when(agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey)) + .thenReturn(Optional.of(memory)); + when(accessControlService.canAccessMemory(memory, requestingUserId, agentId, "WRITE")) + .thenReturn(true); + when(agentMemoryRepository.save(any(AgentMemory.class))) + .thenAnswer(invocation -> invocation.getArgument(0)); + + // Act + boolean result = memoryStore.shareMemoryWithAgents(agentId, memoryKey, targetAgents, requestingUserId); + + // Assert + assertTrue(result); + verify(agentMemoryRepository).save(memory); + assertTrue(memory.canBeSharedWith("agent-1")); + assertTrue(memory.canBeSharedWith("agent-2")); + } + + @Test + void testDeleteMemory_SuccessfulDeletion_ShouldDeleteMemory() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String requestingUserId = "user-123"; + + AgentMemory memory = AgentMemory.builder() + .agentId(agentId) + .memoryKey(memoryKey) + .build(); + + when(agentMemoryRepository.findByAgentIdAndMemoryKey(agentId, memoryKey)) + .thenReturn(Optional.of(memory)); + when(accessControlService.canAccessMemory(memory, requestingUserId, agentId, "DELETE")) + .thenReturn(true); + + // Act + boolean result = memoryStore.deleteMemory(agentId, memoryKey, requestingUserId); + + // Assert + assertTrue(result); + verify(agentMemoryRepository).delete(memory); + } + + @Test + void testGetMemoryStatistics_ShouldReturnCorrectCounts() { + // Arrange + String agentId = "test-agent"; + when(agentMemoryRepository.countByAgentId(agentId)).thenReturn(5L); + when(agentMemoryRepository.countByClassification("PUBLIC")).thenReturn(2L); + when(agentMemoryRepository.countByClassification("PRIVATE")).thenReturn(3L); + when(agentMemoryRepository.countByClassification("SHARED")).thenReturn(1L); + + // Act + Map stats = memoryStore.getMemoryStatistics(agentId); + + // Assert + assertEquals(5L, stats.get("total_memories")); + assertEquals(2L, stats.get("public_memories")); + assertEquals(3L, stats.get("private_memories")); + assertEquals(1L, stats.get("shared_memories")); + } +} \ No newline at end of file diff --git a/dataplane/src/test/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStoreTest.java b/dataplane/src/test/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStoreTest.java new file mode 100644 index 00000000..b056c333 --- /dev/null +++ b/dataplane/src/test/java/io/sentrius/sso/core/services/agents/VectorAgentMemoryStoreTest.java @@ -0,0 +1,218 @@ +package io.sentrius.sso.core.services.agents; + +import io.sentrius.sso.core.model.agents.AgentMemory; +import io.sentrius.sso.core.repository.AgentMemoryRepository; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class VectorAgentMemoryStoreTest { + + @Mock + private PersistentAgentMemoryStore persistentMemoryStore; + + @Mock + private AgentMemoryRepository agentMemoryRepository; + + @Mock + private EmbeddingService embeddingService; + + @Mock + private MemoryAccessControlService accessControlService; + + private VectorAgentMemoryStore vectorMemoryStore; + + @BeforeEach + void setUp() { + vectorMemoryStore = new VectorAgentMemoryStore( + persistentMemoryStore, + agentMemoryRepository, + embeddingService, + accessControlService + ); + } + + @Test + void testStoreMemoryWithEmbedding_Success() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String memoryValue = "test memory content"; + String classification = "PRIVATE"; + String[] markings = {"AI", "MEMORY"}; + String userId = "test-user"; + + AgentMemory savedMemory = new AgentMemory(); + savedMemory.setId(1L); + savedMemory.setAgentId(agentId); + savedMemory.setMemoryKey(memoryKey); + savedMemory.setMemoryValue(memoryValue); + + float[] mockEmbedding = {0.1f, 0.2f, 0.3f}; + + when(persistentMemoryStore.storeMemory(any(), any(), any(), any(), any(), any())) + .thenReturn(savedMemory); + when(embeddingService.isAvailable()).thenReturn(true); + when(embeddingService.embed(anyString())).thenReturn(mockEmbedding); + when(agentMemoryRepository.save(any())).thenReturn(savedMemory); + + // Act + AgentMemory result = vectorMemoryStore.storeMemoryWithEmbedding( + agentId, memoryKey, memoryValue, classification, markings, userId); + + // Assert + assertNotNull(result); + assertEquals(agentId, result.getAgentId()); + verify(persistentMemoryStore).storeMemory(agentId, memoryKey, memoryValue, classification, markings, userId); + verify(embeddingService).embed(anyString()); + verify(agentMemoryRepository).save(savedMemory); + } + + @Test + void testStoreMemoryWithEmbedding_EmbeddingServiceNotAvailable() { + // Arrange + String agentId = "test-agent"; + String memoryKey = "test-key"; + String memoryValue = "test memory content"; + String classification = "PRIVATE"; + String[] markings = {"AI", "MEMORY"}; + String userId = "test-user"; + + AgentMemory savedMemory = new AgentMemory(); + savedMemory.setId(1L); + savedMemory.setAgentId(agentId); + + when(persistentMemoryStore.storeMemory(any(), any(), any(), any(), any(), any())) + .thenReturn(savedMemory); + when(embeddingService.isAvailable()).thenReturn(false); + + // Act + AgentMemory result = vectorMemoryStore.storeMemoryWithEmbedding( + agentId, memoryKey, memoryValue, classification, markings, userId); + + // Assert + assertNotNull(result); + verify(persistentMemoryStore).storeMemory(agentId, memoryKey, memoryValue, classification, markings, userId); + verify(embeddingService, never()).embed(anyString()); + verify(agentMemoryRepository, never()).save(any()); + } + + @Test + void testFindSimilarMemories_WithEmbeddingService() { + // Arrange + String queryText = "test query"; + String userId = "test-user"; + int limit = 5; + double threshold = 0.7; + + float[] queryEmbedding = {0.1f, 0.2f, 0.3f}; + AgentMemory memory1 = createTestMemory(1L, "agent1", "key1", "value1"); + memory1.setEmbedding(new float[]{0.11f, 0.21f, 0.31f}); + + when(embeddingService.isAvailable()).thenReturn(true); + when(embeddingService.embed(queryText)).thenReturn(queryEmbedding); + when(agentMemoryRepository.findSimilarMemories(anyString(), eq(limit * 2))) + .thenReturn(Arrays.asList(memory1)); + when(accessControlService.canAccessMemory(any(), eq(userId), any(), eq("READ"))) + .thenReturn(true); + + // Act + List result = vectorMemoryStore.findSimilarMemories(queryText, userId, limit, threshold); + + // Assert + assertNotNull(result); + assertEquals(1, result.size()); + assertEquals(memory1, result.get(0)); + verify(embeddingService).embed(queryText); + verify(agentMemoryRepository).findSimilarMemories(anyString(), eq(limit * 2)); + } + + @Test + void testFindSimilarMemories_FallbackToTextSearch() { + // Arrange + String queryText = "test query"; + String userId = "test-user"; + int limit = 5; + double threshold = 0.7; + + AgentMemory memory1 = createTestMemory(1L, "agent1", "key1", "value1"); + + when(embeddingService.isAvailable()).thenReturn(false); + when(agentMemoryRepository.searchByMemoryValue(queryText)) + .thenReturn(Arrays.asList(memory1)); + when(accessControlService.canAccessMemory(any(), eq(userId), any(), eq("READ"))) + .thenReturn(true); + + // Act + List result = vectorMemoryStore.findSimilarMemories(queryText, userId, limit, threshold); + + // Assert + assertNotNull(result); + assertEquals(1, result.size()); + verify(agentMemoryRepository).searchByMemoryValue(queryText); + verify(embeddingService, never()).embed(anyString()); + } + + @Test + void testGetVectorStoreStatistics() { + // Arrange + when(agentMemoryRepository.count()).thenReturn(100L); + when(agentMemoryRepository.countMemoriesWithEmbeddings()).thenReturn(75L); + when(embeddingService.isAvailable()).thenReturn(true); + + // Act + Map stats = vectorMemoryStore.getVectorStoreStatistics(); + + // Assert + assertNotNull(stats); + assertEquals(100L, stats.get("total_memories")); + assertEquals(75L, stats.get("memories_with_embeddings")); + assertEquals(75.0, stats.get("embedding_coverage_percentage")); + assertEquals(true, stats.get("embedding_service_available")); + assertEquals(true, stats.get("vector_store_enabled")); + } + + @Test + void testGenerateMissingEmbeddings() { + // Arrange + int batchSize = 10; + AgentMemory memory1 = createTestMemory(1L, "agent1", "key1", "value1"); + float[] mockEmbedding = {0.1f, 0.2f, 0.3f}; + + when(embeddingService.isAvailable()).thenReturn(true); + when(agentMemoryRepository.findMemoriesWithoutEmbeddings(any())) + .thenReturn(Arrays.asList(memory1)); + when(embeddingService.embed(anyString())).thenReturn(mockEmbedding); + when(agentMemoryRepository.save(any())).thenReturn(memory1); + + // Act + vectorMemoryStore.generateMissingEmbeddings(batchSize); + + // Assert + verify(agentMemoryRepository).findMemoriesWithoutEmbeddings(any()); + verify(embeddingService).embed(anyString()); + verify(agentMemoryRepository).save(memory1); + } + + private AgentMemory createTestMemory(Long id, String agentId, String key, String value) { + AgentMemory memory = new AgentMemory(); + memory.setId(id); + memory.setAgentId(agentId); + memory.setMemoryKey(key); + memory.setMemoryValue(value); + memory.setClassification("PRIVATE"); + memory.setAccessLevel("AGENT_ONLY"); + return memory; + } +} \ No newline at end of file diff --git a/docker/keycloak/realms/sentrius-realm.json.template b/docker/keycloak/realms/sentrius-realm.json.template index ae86dc63..059350c5 100644 --- a/docker/keycloak/realms/sentrius-realm.json.template +++ b/docker/keycloak/realms/sentrius-realm.json.template @@ -17,7 +17,8 @@ "attributes": { "access.token.claim": "true", "id.token.claim": "true", - "userinfo.token.claim": "true" + "userinfo.token.claim": "true", + "initial.user.type": "System Admin" }, "protocolMappers": [ { @@ -33,6 +34,20 @@ "jsonType.label": "String", "user.attribute": "userType" } + }, + { + "name": "initialUserType", + "protocol": "openid-connect", + "protocolMapper": "oidc-hardcoded-claim-mapper", + "consentRequired": false, + "config": { + "claim.name": "initial_user_type", + "jsonType.label": "String", + "claim.value": "System Admin", + "access.token.claim": "true", + "id.token.claim": "true", + "userinfo.token.claim": "true" + } } ] }, diff --git a/docs/capabilities-api.md b/docs/capabilities-api.md index 51a85cf0..1052f17f 100644 --- a/docs/capabilities-api.md +++ b/docs/capabilities-api.md @@ -34,7 +34,7 @@ Returns all available endpoints (both REST and Verb) with optional filtering: "type": "REST", "httpMethod": "GET", "path": "/api/v1/users/list", - "className": "io.sentrius.sso.controllers.api.UserApiController", + "className": "io.sentrius.sso.controllers.api.users.UserApiController", "methodName": "listusers", "requiresAuthentication": true, "accessLimitations": { diff --git a/docs/vector-store-enhancement.md b/docs/vector-store-enhancement.md new file mode 100644 index 00000000..394d1804 --- /dev/null +++ b/docs/vector-store-enhancement.md @@ -0,0 +1,344 @@ +# Agent Memory Store Vector Search Enhancement + +## Overview + +The Agent Memory Store has been enhanced with vector search capabilities, enabling semantic similarity search while maintaining the existing ABAC (Attribute-Based Access Control) security model. This enhancement allows agents to discover conceptually related memories through embeddings rather than just exact keyword matches. + +## Features + +### Core Vector Store Capabilities + +1. **PostgreSQL + pgvector Integration** + - Uses pgvector extension for efficient vector operations + - Stores 1536-dimensional embeddings (compatible with OpenAI's text-embedding-3-small) + - Cosine similarity distance calculations + +2. **Automatic Embedding Generation** + - Integrates with OpenAI's embedding API + - Configurable via `spring.ai.openai.api-key` property + - Automatic embedding generation for new memories when enabled + +3. **Hybrid Search Capabilities** + - Combines vector similarity with traditional text search + - Maintains all existing markings and classification filters + - Preserves ABAC security model + +4. **Access Control Integration** + - All vector searches respect existing security policies + - Markings-based filtering applied to vector results + - User attribute validation maintained + +## Database Schema Changes + +### Migration V22: Vector Support + +```sql +-- Enable pgvector extension +CREATE EXTENSION IF NOT EXISTS vector; + +-- Add embedding column to agent_memory table +ALTER TABLE agent_memory ADD COLUMN embedding vector(1536); + +-- Create vector similarity index +CREATE INDEX idx_agent_memory_embedding ON agent_memory +USING ivfflat (embedding vector_cosine_ops); + +-- Create hybrid search indexes +CREATE INDEX idx_agent_memory_embedding_classification +ON agent_memory (classification, embedding); + +CREATE INDEX idx_agent_memory_embedding_markings +ON agent_memory (markings, embedding); +``` + +## Configuration + +### Required Properties + +```properties +# OpenAI API Key (required for embedding generation) +spring.ai.openai.api-key=your-openai-api-key + +# Vector store configuration (optional) +sentrius.memory.vector.dimension=1536 +sentrius.memory.vector.similarity-threshold=0.7 +sentrius.memory.vector.enabled=true +``` + +### Optional Configuration + +```properties +# Database connection must support pgvector +spring.datasource.url=jdbc:postgresql://localhost:5432/sentrius_db +``` + +## API Endpoints + +### Enhanced Memory Storage + +```http +POST /api/v1/agents/memory/{agentId}?generateEmbedding=true +Content-Type: application/json + +{ + "memoryKey": "user_preferences", + "memoryValue": "User prefers dark mode and compact layouts", + "classification": "PRIVATE", + "markings": ["UI", "PREFERENCES"] +} +``` + +### Semantic Search + +```http +POST /api/v1/agents/memory/search/semantic +Content-Type: application/json + +{ + "query": "user interface settings", + "limit": 10, + "threshold": 0.7 +} +``` + +### Agent-Specific Semantic Search + +```http +POST /api/v1/agents/memory/search/semantic/{agentId} +Content-Type: application/json + +{ + "query": "machine learning algorithms", + "limit": 5, + "threshold": 0.8 +} +``` + +### Hybrid Search + +```http +POST /api/v1/agents/memory/search/hybrid +Content-Type: application/json + +{ + "searchTerm": "neural networks", + "markings": "AI,RESEARCH", + "limit": 10, + "threshold": 0.7 +} +``` + +### Embedding Management + +```http +# Generate embeddings for existing memories without embeddings +POST /api/v1/agents/memory/embeddings/generate?batchSize=100 + +# Get vector store statistics +GET /api/v1/agents/memory/statistics/vector +``` + +## Usage Examples + +### Java Service Layer + +```java +@Autowired +private VectorAgentMemoryStore vectorMemoryStore; + +// Store memory with automatic embedding generation +AgentMemory memory = vectorMemoryStore.storeMemoryWithEmbedding( + "agent-001", + "conversation_summary", + "Discussion about machine learning best practices", + "SHARED", + new String[]{"AI", "CONVERSATION"}, + "user-123" +); + +// Find semantically similar memories +List similar = vectorMemoryStore.findSimilarMemories( + "artificial intelligence techniques", + "user-123", + 10, + 0.7 +); + +// Hybrid search with markings filter +List results = vectorMemoryStore.hybridSearch( + "deep learning", + "AI,RESEARCH", + "user-123", + 5, + 0.8 +); +``` + +### REST API Usage + +```bash +# Store memory with embedding +curl -X POST "http://localhost:8080/api/v1/agents/memory/agent-001?generateEmbedding=true" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $JWT_TOKEN" \ + -d '{ + "memoryKey": "ai_research_notes", + "memoryValue": "Recent advances in transformer architectures show promising results", + "classification": "SHARED", + "markings": ["AI", "RESEARCH", "TRANSFORMERS"] + }' + +# Semantic search +curl -X POST "http://localhost:8080/api/v1/agents/memory/search/semantic" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $JWT_TOKEN" \ + -d '{ + "query": "neural network architectures", + "limit": 5, + "threshold": 0.75 + }' + +# Check vector store statistics +curl -X GET "http://localhost:8080/api/v1/agents/memory/statistics/vector" \ + -H "Authorization: Bearer $JWT_TOKEN" +``` + +## Search Patterns + +### 1. Exact Match + Semantic Fallback + +```java +// First try exact match, then semantic if no results +List exact = persistentMemoryStore.findMemoriesByMarkings("AI", userId); +if (exact.isEmpty()) { + List semantic = vectorMemoryStore.findSimilarMemories("AI", userId, 10, 0.7); +} +``` + +### 2. Hybrid Search for Best Coverage + +```java +// Combine text matching with semantic similarity +List hybrid = vectorMemoryStore.hybridSearch( + "machine learning", "AI", userId, 10, 0.7); +``` + +### 3. Cross-Agent Knowledge Discovery + +```java +// Find related memories across all accessible agents +List discoveries = vectorMemoryStore.findSimilarMemories( + "recommendation systems", userId, 20, 0.6); +``` + +## Performance Considerations + +### Embedding Generation + +- Embeddings are generated asynchronously when possible +- Batch processing available for existing memories +- OpenAI API rate limits apply (consider caching strategies) + +### Vector Search Performance + +- pgvector indexes optimize similarity queries +- Consider partitioning for large datasets +- Monitor query performance and adjust similarity thresholds + +### Storage Impact + +- Each embedding adds ~6KB per memory (1536 float values) +- Consider memory lifecycle policies for embedding cleanup +- Index maintenance overhead for large datasets + +## Security Model + +### Access Control Preservation + +All vector search operations maintain existing security: + +- **ABAC Policies**: User attributes checked before returning results +- **Markings**: Classification and markings filters applied to vector results +- **Agent Ownership**: Agent-specific searches respect ownership rules +- **Expiration**: Expired memories excluded from vector searches + +### Privacy Considerations + +- Embeddings contain semantic information about original text +- Consider classification-based embedding access policies +- Audit trail maintained for all vector operations + +## Monitoring and Maintenance + +### Statistics Available + +```json +{ + "total_memories": 1500, + "memories_with_embeddings": 1200, + "embedding_coverage_percentage": 80.0, + "embedding_service_available": true, + "vector_store_enabled": true +} +``` + +### Maintenance Operations + +```bash +# Generate missing embeddings +curl -X POST "http://localhost:8080/api/v1/agents/memory/embeddings/generate?batchSize=50" + +# Clean up expired memories (includes embeddings) +curl -X POST "http://localhost:8080/api/v1/agents/memory/cleanup/expired" +``` + +## Migration Path + +### For Existing Installations + +1. **Database Setup**: Apply migration V22 to add vector support +2. **Configuration**: Add OpenAI API key to configuration +3. **Embedding Generation**: Use batch endpoint to generate embeddings for existing memories +4. **Application Update**: Deploy updated services with vector capabilities +5. **Verification**: Check vector store statistics and test semantic search + +### Gradual Adoption + +- Vector features are optional and backwards compatible +- Existing text search continues to work unchanged +- Embeddings generated on-demand for new memories +- Fallback to text search when vector search unavailable + +## Troubleshooting + +### Common Issues + +1. **No embeddings generated**: Check OpenAI API key configuration +2. **Slow vector queries**: Verify pgvector indexes are created +3. **Memory without embeddings**: Use batch generation endpoint +4. **API rate limits**: Implement request throttling for embedding generation + +### Logs to Monitor + +``` +# Embedding service availability +VectorStoreConfig: Vector store configuration: enabled=true + +# Embedding generation +VectorAgentMemoryStore: Generated embedding for memory: agent=agent-001, key=summary + +# Search performance +VectorAgentMemoryStore: Semantic search query: neural networks, limit: 10, threshold: 0.7 +``` + +## Future Enhancements + +### Planned Features + +- **Embedding Model Selection**: Support for different embedding models +- **Vector Quantization**: Optimize storage for large scale deployments +- **Semantic Clustering**: Group related memories automatically +- **Cross-Modal Search**: Image and text embedding integration +- **Vector Database Options**: Support for specialized vector databases + +This enhancement provides a solid foundation for semantic memory search while maintaining the security and access control features of the existing system. \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java b/integration-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java index 4b379185..5c8912be 100644 --- a/integration-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java +++ b/integration-proxy/src/main/java/io/sentrius/sso/config/SecurityConfig.java @@ -10,6 +10,7 @@ import io.sentrius.sso.core.security.CustomAuthenticationSuccessHandler; import io.sentrius.sso.core.services.CustomUserDetailsService; import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.security.KeycloakService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; @@ -33,6 +34,7 @@ public class SecurityConfig { private final CustomUserDetailsService userDetailsService; private final CustomAuthenticationSuccessHandler successHandler; private final KeycloakAuthSuccessHandler keycloakAuthSuccessHandler; + final UserService userService; @Value("${https.required:false}") // Default is false @@ -84,8 +86,18 @@ public JwtAuthenticationConverter jwtAuthenticationConverterForKeycloak() { User user = userService.getUserByUsername(username); if (user == null) { + + var initialUserType = jwt.getClaimAsString("initial_user_type"); var type = userService.getUserType( UserType.createUnknownUser()); + if (null != initialUserType) { + log.info("Initial user type from token: {}", initialUserType); + type = userService.getUserType(initialUserType); + } else { + log.warn("No initial user type found in token, defaulting to UNKNOWN"); + } + + if (type.isEmpty()) { log.error("Failed to create base user type"); return authorities; diff --git a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/EmbeddingProxyController.java b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/EmbeddingProxyController.java new file mode 100644 index 00000000..9f34a119 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/EmbeddingProxyController.java @@ -0,0 +1,248 @@ +package io.sentrius.sso.controllers.api; + +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.controllers.BaseController; +import io.sentrius.sso.core.integrations.external.ExternalIntegrationDTO; +import io.sentrius.sso.core.services.ErrorOutputService; +import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.security.IntegrationSecurityTokenService; +import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.core.utils.JsonUtil; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.client.RestTemplate; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Proxy controller for OpenAI embedding operations. + * Handles embedding generation through the integration proxy for proper security and tracing. + */ +@Slf4j +@RestController +@RequestMapping("/api/v1/embeddings") +public class EmbeddingProxyController extends BaseController { + + private final IntegrationSecurityTokenService integrationSecurityTokenService; + private final KeycloakService keycloakService; + private final RestTemplate restTemplate; + private final String openAiApiUrl = "https://api.openai.com/v1/embeddings"; + + public EmbeddingProxyController( + UserService userService, + SystemOptions systemOptions, + ErrorOutputService errorOutputService, + IntegrationSecurityTokenService integrationSecurityTokenService, + KeycloakService keycloakService, + RestTemplate restTemplate) { + super(userService, systemOptions, errorOutputService); + this.integrationSecurityTokenService = integrationSecurityTokenService; + this.keycloakService = keycloakService; + this.restTemplate = restTemplate; + } + + /** + * Generate embedding for the given text using OpenAI's embedding model + */ + @PostMapping("/generate") + public ResponseEntity generateEmbedding( + @RequestHeader("Authorization") String token, + @RequestBody Map request, + HttpServletRequest httpRequest, + HttpServletResponse httpResponse) { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token for embedding generation"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(httpRequest, httpResponse); + if (operatingUser == null) { + var username = keycloakService.extractUsername(compactJwt); + operatingUser = userService.getUserByUsername(username); + } + + if (operatingUser == null) { + log.warn("No operating user found for embedding generation"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("User not found"); + } + + // Get OpenAI integration token + var openAiToken = integrationSecurityTokenService.findByConnectionType("openai") + .stream().findFirst().orElse(null); + + if (openAiToken == null) { + log.warn("No OpenAI integration found for embedding generation"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("No OpenAI integration found"); + } + + try { + ExternalIntegrationDTO integrationDTO = JsonUtil.MAPPER.readValue( + openAiToken.getConnectionInfo(), ExternalIntegrationDTO.class); + + String text = (String) request.get("text"); + if (text == null || text.trim().isEmpty()) { + text = (String) request.get("input"); + if (text == null || text.trim().isEmpty()) { + return ResponseEntity.badRequest().body("Text is required for embedding generation"); + } + } + + // Prepare OpenAI API request + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + integrationDTO.getApiToken()); + headers.set("Content-Type", "application/json"); + + Map requestBody = new HashMap<>(); + requestBody.put("input", text); + requestBody.put("model", "text-embedding-3-small"); + + HttpEntity> entity = new HttpEntity<>(requestBody, headers); + + log.debug("Generating embedding for user: {}, text length: {}", + operatingUser.getUsername(), text.length()); + + // Make API call to OpenAI + ResponseEntity response = restTemplate.exchange( + openAiApiUrl, HttpMethod.POST, entity, Map.class); + + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + Map responseBody = response.getBody(); + @SuppressWarnings("unchecked") + List> data = (List>) responseBody.get("data"); + + if (data != null && !data.isEmpty()) { + @SuppressWarnings("unchecked") + List embedding = (List) data.get(0).get("embedding"); + + // Convert to float array for consistency with database storage + float[] result = new float[embedding.size()]; + for (int i = 0; i < embedding.size(); i++) { + result[i] = embedding.get(i).floatValue(); + } + + Map responseMap = new HashMap<>(); + responseMap.put("embedding", result); + responseMap.put("dimensions", result.length); + responseMap.put("text_length", text.length()); + + log.debug("Generated embedding with {} dimensions for user: {}", + result.length, operatingUser.getUsername()); + + return ResponseEntity.ok(responseMap); + } + } + + log.warn("Failed to generate embedding - unexpected response format"); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body("Failed to generate embedding"); + + } catch (Exception e) { + log.error("Error generating embedding for user: {}", operatingUser.getUsername(), e); + return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body("Error generating embedding: " + e.getMessage()); + } + } + + /** + * Generate embeddings for multiple texts in batch + */ + @PostMapping("/generate/batch") + public ResponseEntity generateEmbeddingBatch( + @RequestHeader("Authorization") String token, + @RequestBody Map request, + HttpServletRequest httpRequest, + HttpServletResponse httpResponse) { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token for batch embedding generation"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(httpRequest, httpResponse); + if (operatingUser == null) { + var username = keycloakService.extractUsername(compactJwt); + operatingUser = userService.getUserByUsername(username); + } + + if (operatingUser == null) { + log.warn("No operating user found for batch embedding generation"); + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("User not found"); + } + + @SuppressWarnings("unchecked") + List texts = (List) request.get("texts"); + if (texts == null || texts.isEmpty()) { + return ResponseEntity.badRequest().body("Texts array is required for batch embedding generation"); + } + + Map results = new HashMap<>(); + + // Process each text individually for now (could be optimized for true batch processing) + for (String text : texts) { + Map singleRequest = new HashMap<>(); + singleRequest.put("text", text); + + ResponseEntity response = generateEmbedding(token, singleRequest, httpRequest, httpResponse); + + if (response.getStatusCode().is2xxSuccessful()) { + @SuppressWarnings("unchecked") + Map responseBody = (Map) response.getBody(); + if (responseBody != null && responseBody.containsKey("embedding")) { + float[] embedding = (float[]) responseBody.get("embedding"); + results.put(text, embedding); + } + } + } + + Map batchResponse = new HashMap<>(); + batchResponse.put("embeddings", results); + batchResponse.put("processed_count", results.size()); + batchResponse.put("total_requested", texts.size()); + + log.info("Generated batch embeddings: {}/{} successful for user: {}", + results.size(), texts.size(), operatingUser.getUsername()); + + return ResponseEntity.ok(batchResponse); + } + + /** + * Check if embedding service is available + */ + @GetMapping("/status") + public ResponseEntity getEmbeddingServiceStatus( + @RequestHeader("Authorization") String token, + HttpServletRequest httpRequest, + HttpServletResponse httpResponse) { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + if (!keycloakService.validateJwt(compactJwt)) { + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var openAiToken = integrationSecurityTokenService.findByConnectionType("openai") + .stream().findFirst().orElse(null); + + Map status = new HashMap<>(); + status.put("available", openAiToken != null); + status.put("integration_configured", openAiToken != null); + status.put("service", "OpenAI Embeddings"); + status.put("model", "text-embedding-3-small"); + + return ResponseEntity.ok(status); + } +} \ No newline at end of file diff --git a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/JiraProxyController.java b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/JiraProxyController.java index cfa53b63..3db0cf72 100644 --- a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/JiraProxyController.java +++ b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/JiraProxyController.java @@ -4,6 +4,8 @@ import java.util.Optional; import java.util.concurrent.ExecutionException; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; @@ -20,6 +22,7 @@ import io.sentrius.sso.core.services.UserService; import io.sentrius.sso.core.services.security.IntegrationSecurityTokenService; import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.core.utils.JsonUtil; import io.sentrius.sso.integrations.exceptions.HttpException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -162,7 +165,7 @@ public ResponseEntity fetchJiraIssue( @PostMapping("/rest/api/3/issue/comment") @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) - public ResponseEntity addComment( + public ResponseEntity addCommentToJiraIssue( @RequestHeader("Authorization") String token, @RequestParam(name="issueKey") String issueKey, @RequestBody CommentRequest commentRequest, @@ -217,6 +220,67 @@ public ResponseEntity addComment( } } + @GetMapping("/rest/api/3/issue/comment") + @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) + public ResponseEntity getJiraIssueComments( + @RequestHeader("Authorization") String token, + @RequestParam(name="issueKey") String issueKey, + @RequestBody CommentRequest commentRequest, + HttpServletRequest request, + HttpServletResponse response + ) throws JsonProcessingException { + + Span span = tracer.spanBuilder("jira-proxy-add-comment").startSpan(); + try (Scope scope = span.makeCurrent()) { + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(request, response); + if (null == operatingUser) { + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("User not authenticated"); + } + + List jiraIntegrations = integrationSecurityTokenService + .findByConnectionType("jira"); + + if (jiraIntegrations.isEmpty()) { + return ResponseEntity.status(HttpStatus.SC_NOT_FOUND).body("No JIRA integration configured"); + } + + IntegrationSecurityToken jiraIntegration = jiraIntegrations.get(0); + JiraService jiraService = new JiraService(new RestTemplate(), jiraIntegration); + + // Extract comment text from the request + String commentText = extractCommentText(commentRequest); + if (commentText == null || commentText.trim().isEmpty()) { + return ResponseEntity.badRequest().body("Comment text is required"); + } + + List comments = jiraService.getComments(issueKey); + + span.setAttribute("issue.key", issueKey); + span.setAttribute("comment.success", comments != null && !comments.isEmpty()); + + if (comments != null && !comments.isEmpty()) { + ObjectNode responseNode = JsonUtil.MAPPER.createObjectNode(); + responseNode.putArray("comments").addAll(comments.stream() + .map(TextNode::new) + .toList()); + return ResponseEntity.ok(responseNode); + } else { + return ResponseEntity.status(HttpStatus.SC_INTERNAL_SERVER_ERROR) + .body("Failed to add comment to issue"); + } + + } finally { + span.end(); + } + } + @PutMapping("/rest/api/3/issue/assignee") @LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) public ResponseEntity assignJiraIssue( diff --git a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/MemoryController.java b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/MemoryController.java new file mode 100644 index 00000000..4a3fa666 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/MemoryController.java @@ -0,0 +1,347 @@ +package io.sentrius.sso.controllers.api; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; +import io.sentrius.sso.config.ApplicationEnvironmentConfig; +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.controllers.BaseController; +import io.sentrius.sso.core.integrations.external.ExternalIntegrationDTO; +import io.sentrius.sso.core.model.verbs.Endpoint; +import io.sentrius.sso.core.services.ATPLPolicyService; +import io.sentrius.sso.core.services.ErrorOutputService; +import io.sentrius.sso.core.services.UserService; +import io.sentrius.sso.core.services.agents.AgentService; +import io.sentrius.sso.core.services.security.CryptoService; +import io.sentrius.sso.core.services.security.IntegrationSecurityTokenService; +import io.sentrius.sso.core.services.security.KeycloakService; +import io.sentrius.sso.core.services.security.ZeroTrustAccessTokenService; +import io.sentrius.sso.core.services.security.ZeroTrustRequestService; +import io.sentrius.sso.core.services.terminal.SessionTrackingService; +import io.sentrius.sso.core.utils.JsonUtil; +import io.sentrius.sso.genai.GenerativeAPI; +import io.sentrius.sso.genai.Message; +import io.sentrius.sso.genai.model.EmbeddingRequest; +import io.sentrius.sso.genai.model.LLMRequest; +import io.sentrius.sso.genai.model.endpoints.EmbeddingApiRequest; +import io.sentrius.sso.genai.model.endpoints.RawConversationRequest; +import io.sentrius.sso.genai.spring.ai.AgentCommunicationMemoryStore; +import io.sentrius.sso.integrations.exceptions.HttpException; +import io.sentrius.sso.provenance.ProvenanceEvent; +import io.sentrius.sso.provenance.kafka.ProvenanceKafkaProducer; +import io.sentrius.sso.security.ApiKey; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.extern.slf4j.Slf4j; +import org.apache.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping("/api/v1/memory") +@Slf4j +public class MemoryController extends BaseController { + + final CryptoService cryptoService; + final SessionTrackingService sessionTrackingService; + final KeycloakService keycloakService; + final ATPLPolicyService atplPolicyService; + final ZeroTrustAccessTokenService ztatService; + final ZeroTrustRequestService ztrService; + final IntegrationSecurityTokenService integrationSecurityTokenService; + final AgentService agentService; + private final ApplicationEnvironmentConfig applicationConfig; + final AgentCommunicationMemoryStore agentCommunicationMemoryStore; + final ProvenanceKafkaProducer provenanceKafkaProducer; + + Tracer tracer = GlobalOpenTelemetry.getTracer("io.sentrius.sso"); + + protected MemoryController( + UserService userService, SystemOptions systemOptions, + ErrorOutputService errorOutputService, CryptoService cryptoService, + SessionTrackingService sessionTrackingService, KeycloakService keycloakService, + ATPLPolicyService atplPolicyService, ZeroTrustAccessTokenService ztatService, ZeroTrustRequestService ztrService, + IntegrationSecurityTokenService integrationSecurityTokenService, AgentService agentService, + ApplicationEnvironmentConfig applicationConfig, ProvenanceKafkaProducer provenanceKafkaProducer + ) { + super(userService, systemOptions, errorOutputService); + this.cryptoService = cryptoService; + this.sessionTrackingService = sessionTrackingService; + this.keycloakService = keycloakService; + this.atplPolicyService = atplPolicyService; + this.ztatService = ztatService; + this.ztrService = ztrService; + this.integrationSecurityTokenService = integrationSecurityTokenService; + this.agentService = agentService; + this.applicationConfig = applicationConfig; + agentCommunicationMemoryStore = new AgentCommunicationMemoryStore(agentService); + this.provenanceKafkaProducer = provenanceKafkaProducer; + } + + @PostMapping("/completions") + @Endpoint(description = "Proxy for OpenAI completions endpoint") + // require a registered user with an active ztat + //@LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) + public ResponseEntity chat(@RequestHeader("Authorization") String token, + @RequestHeader("X-Communication-Id") String communicationId, + HttpServletRequest request, HttpServletResponse response, + @RequestBody String rawBody) throws JsonProcessingException, HttpException { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(request, response ); + + // Extract agent identity from the JWT + String agentId = keycloakService.extractAgentId(compactJwt); + + if (null == operatingUser) { + log.warn("No operating user found for agent: {}", agentId); + var username = keycloakService.extractUsername(compactJwt); + log.info("Extracted username from JWT: {}", username); + operatingUser = userService.getUserByUsername(username); + + } + + log.info("Operating user: {}", operatingUser); + + // we've reached this point, so we can assume the user is allowed to access OpenAI + + var openAiToken = + integrationSecurityTokenService.findByConnectionType("openai").stream().findFirst().orElse(null); + if (openAiToken == null) { + log.info("no integration"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("No OpenAI integration found"); + } + + + + ExternalIntegrationDTO externalIntegrationDTO = null; + try { + externalIntegrationDTO = JsonUtil.MAPPER.readValue(openAiToken.getConnectionInfo(), + ExternalIntegrationDTO.class); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + ApiKey key = + ApiKey.builder().apiKey(externalIntegrationDTO.getApiToken()).principal(externalIntegrationDTO.getUsername()).build(); + + GenerativeAPI endpoint = new GenerativeAPI(key); + + + + log.info("Chat request: {}", rawBody); + LLMRequest chatRequest = JsonUtil.MAPPER.readValue(rawBody, LLMRequest.class); + + + var comm = agentService.saveCommunication(communicationId, + operatingUser.getUsername(), + applicationConfig.getServiceName(), + "chat_request", + rawBody + ); + + + ProvenanceEvent event = ProvenanceEvent.builder() + .eventId(communicationId) + .sessionId(communicationId) + .actor(operatingUser.getUsername()) + .triggeringUser("LLM") + .eventType(ProvenanceEvent.EventType.KNOWLEDGE_REQUESTED) + .outputSummary("prompt LLM" + chatRequest.getMessages().get(0).getContent()) + .timestamp(LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC)) + .build(); + provenanceKafkaProducer.send(event); + + event = ProvenanceEvent.builder() + .eventId(communicationId) + .sessionId(communicationId) + .actor("LLM") + .triggeringUser(operatingUser.getUsername()) + .eventType(ProvenanceEvent.EventType.KNOWLEDGE_GENERATED) + .outputSummary("prompt LLM") + .timestamp(LocalDateTime.now().toInstant(java.time.ZoneOffset.UTC)) + .build(); + provenanceKafkaProducer.send(event); + + + + Span span = tracer.spanBuilder("AgentToAgentCommunication").startSpan(); + int retries = 2; + try (Scope scope = span.makeCurrent()) { + HttpException httpException = null; + do { + try { + var resp = endpoint.sample(RawConversationRequest.builder().request(chatRequest).build()); + span.setAttribute("communication.id", comm.get().getId().toString()); + span.setAttribute("source.agent", operatingUser.getUsername()); + span.setAttribute("target.agent", "SYSTEM"); + span.setAttribute("message.type", "interpretation_request"); + return ResponseEntity.ok(resp); + }catch(HttpException e){ + if (e.getMessage().contains("timeout")) { + httpException = e; + } else { + throw e; + } + } + } while(retries-- > 0); + if (null != httpException) { + throw httpException; + } + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } finally { + span.end(); + } + + return null; + } + + @PostMapping("/justify") + // require a registered user with an active ztat + //@LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) + public ResponseEntity justify(@RequestHeader("Authorization") String token, + @RequestHeader("X-Communication-Id") String communicationId, + HttpServletRequest request, HttpServletResponse response, + @RequestBody String rawBody) throws JsonProcessingException, HttpException { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(request, response ); + + // Extract agent identity from the JWT + String agentId = keycloakService.extractAgentId(compactJwt); + + if (null == operatingUser) { + log.warn("No operating user found for agent: {}", agentId); + var username = keycloakService.extractUsername(compactJwt); + operatingUser = userService.getUserByUsername(username); + + } + + // we've reached this point, so we can assume the user is allowed to access OpenAI + + var openAiToken = + integrationSecurityTokenService.findByConnectionType("openai").stream().findFirst().orElse(null); + if (openAiToken == null) { + log.info("no integration"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("No OpenAI integration found"); + } + ExternalIntegrationDTO externalIntegrationDTO = null; + try { + externalIntegrationDTO = JsonUtil.MAPPER.readValue(openAiToken.getConnectionInfo(), + ExternalIntegrationDTO.class); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + ApiKey key = + ApiKey.builder().apiKey(externalIntegrationDTO.getApiToken()).principal(externalIntegrationDTO.getUsername()).build(); + + GenerativeAPI endpoint = new GenerativeAPI(key); + + log.info("Chat request: {}", rawBody); + LLMRequest chatRequest = JsonUtil.MAPPER.readValue(rawBody, LLMRequest.class); + var previousCommunications = agentService.getCommunications( + UUID.fromString(communicationId)); + + /** + * Create a new list of messages and add the previous messages to it + */ + var newMessages = new ArrayList(); + for (var previousCommunication : previousCommunications) { + try { + var message = JsonUtil.MAPPER.readValue(previousCommunication.getPayload(), Message.class); + newMessages.add(message); + } catch (JsonProcessingException e) { + // not a message? + } + } + newMessages.addAll(chatRequest.getMessages()); + chatRequest.setMessages(newMessages); + + var comm = agentService.saveCommunication(communicationId, + operatingUser.getUsername(), + applicationConfig.getServiceName(), + "chat_request", + rawBody + ); + + Span span = tracer.spanBuilder("AgentToAgentCommunication").startSpan(); + try (Scope scope = span.makeCurrent()) { + var resp = endpoint.sample(RawConversationRequest.builder().request(chatRequest).build()); + span.setAttribute("communication.id", comm.get().getId().toString()); + span.setAttribute("source.agent", operatingUser.getUsername()); + span.setAttribute("target.agent", "SYSTEM"); + span.setAttribute("message.type", "interpretation_request"); + return ResponseEntity.ok(resp); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } finally { + span.end(); + } + + + } + + @PostMapping("/embeddings") + @Endpoint(description = "Proxy for OpenAI embeddings endpoint") + public ResponseEntity getEmbedding(@RequestHeader("Authorization") String token, + @RequestHeader("X-Communication-Id") String communicationId, + HttpServletRequest request, HttpServletResponse response, + @RequestBody String rawBody) throws JsonProcessingException, HttpException { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(request, response); + if (operatingUser == null) { + var username = keycloakService.extractUsername(compactJwt); + operatingUser = userService.getUserByUsername(username); + } + + var openAiToken = integrationSecurityTokenService.findByConnectionType("openai").stream().findFirst().orElse(null); + if (openAiToken == null) { + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("No OpenAI integration found"); + } + + var externalIntegrationDTO = JsonUtil.MAPPER.readValue(openAiToken.getConnectionInfo(), ExternalIntegrationDTO.class); + var key = ApiKey.builder().apiKey(externalIntegrationDTO.getApiToken()).principal(externalIntegrationDTO.getUsername()).build(); + var generativeAPI = new GenerativeAPI(key); + + EmbeddingRequest embeddingRequest = JsonUtil.MAPPER.readValue(rawBody, EmbeddingRequest.class); + + EmbeddingApiRequest embeddingApiRequest = EmbeddingApiRequest.builder().input(embeddingRequest.getInput()).model(embeddingRequest.getModel()).build(); + // Example payload: {"input": "get user endpoint", "model": "text-embedding-3-small"} + var resp = generativeAPI.getEmbedding(embeddingApiRequest); + + return ResponseEntity.ok(resp); + } +} diff --git a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/OpenAIProxyController.java b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/OpenAIProxyController.java index 0a83ee29..d30d5da7 100644 --- a/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/OpenAIProxyController.java +++ b/integration-proxy/src/main/java/io/sentrius/sso/controllers/api/OpenAIProxyController.java @@ -13,6 +13,7 @@ import io.sentrius.sso.core.config.SystemOptions; import io.sentrius.sso.core.controllers.BaseController; import io.sentrius.sso.core.integrations.external.ExternalIntegrationDTO; +import io.sentrius.sso.core.model.verbs.Endpoint; import io.sentrius.sso.core.services.ATPLPolicyService; import io.sentrius.sso.core.services.ErrorOutputService; import io.sentrius.sso.core.services.UserService; @@ -27,7 +28,9 @@ import io.sentrius.sso.genai.GenerativeAPI; import io.sentrius.sso.genai.Message; +import io.sentrius.sso.genai.model.EmbeddingRequest; import io.sentrius.sso.genai.model.LLMRequest; +import io.sentrius.sso.genai.model.endpoints.EmbeddingApiRequest; import io.sentrius.sso.genai.model.endpoints.RawConversationRequest; import io.sentrius.sso.genai.spring.ai.AgentCommunicationMemoryStore; import io.sentrius.sso.integrations.exceptions.HttpException; @@ -87,6 +90,7 @@ protected OpenAIProxyController( } @PostMapping("/completions") + @Endpoint(description = "Proxy for OpenAI completions endpoint") // require a registered user with an active ztat //@LimitAccess(applicationAccess = {ApplicationAccessEnum.CAN_LOG_IN}) public ResponseEntity chat(@RequestHeader("Authorization") String token, @@ -303,4 +307,42 @@ public ResponseEntity justify(@RequestHeader("Authorization") String token, } + + @PostMapping("/embeddings") + @Endpoint(description = "Proxy for OpenAI embeddings endpoint") + public ResponseEntity getEmbedding(@RequestHeader("Authorization") String token, + @RequestHeader("X-Communication-Id") String communicationId, + HttpServletRequest request, HttpServletResponse response, + @RequestBody String rawBody) throws JsonProcessingException, HttpException { + + String compactJwt = token.startsWith("Bearer ") ? token.substring(7) : token; + + if (!keycloakService.validateJwt(compactJwt)) { + log.warn("Invalid Keycloak token"); + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("Invalid Keycloak token"); + } + + var operatingUser = getOperatingUser(request, response); + if (operatingUser == null) { + var username = keycloakService.extractUsername(compactJwt); + operatingUser = userService.getUserByUsername(username); + } + + var openAiToken = integrationSecurityTokenService.findByConnectionType("openai").stream().findFirst().orElse(null); + if (openAiToken == null) { + return ResponseEntity.status(HttpStatus.SC_UNAUTHORIZED).body("No OpenAI integration found"); + } + + var externalIntegrationDTO = JsonUtil.MAPPER.readValue(openAiToken.getConnectionInfo(), ExternalIntegrationDTO.class); + var key = ApiKey.builder().apiKey(externalIntegrationDTO.getApiToken()).principal(externalIntegrationDTO.getUsername()).build(); + var generativeAPI = new GenerativeAPI(key); + + EmbeddingRequest embeddingRequest = JsonUtil.MAPPER.readValue(rawBody, EmbeddingRequest.class); + + EmbeddingApiRequest embeddingApiRequest = EmbeddingApiRequest.builder().input(embeddingRequest.getInput()).model(embeddingRequest.getModel()).build(); + // Example payload: {"input": "get user endpoint", "model": "text-embedding-3-small"} + var resp = generativeAPI.getEmbedding(embeddingApiRequest); + + return ResponseEntity.ok(resp); + } } diff --git a/integration-proxy/src/main/java/io/sentrius/sso/services/QdrantMemoryStore.java b/integration-proxy/src/main/java/io/sentrius/sso/services/QdrantMemoryStore.java new file mode 100644 index 00000000..394db5e4 --- /dev/null +++ b/integration-proxy/src/main/java/io/sentrius/sso/services/QdrantMemoryStore.java @@ -0,0 +1,60 @@ +package io.sentrius.sso.services; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.databind.JsonNode; +import io.sentrius.sso.core.data.VectorMemoryStore; +import io.sentrius.sso.core.data.VectorResult; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +public class QdrantMemoryStore implements VectorMemoryStore { + + private final RestTemplate restTemplate = new RestTemplate(); + private final String baseUrl; + + public QdrantMemoryStore(String baseUrl) { + this.baseUrl = baseUrl; + } + + @Override + public void upsert(String collection, String id, float[] vector, Map payload) { + var request = Map.of("points", List.of(Map.of( + "id", id, + "vector", vector, + "payload", payload + ))); + restTemplate.postForEntity(baseUrl + "/collections/" + collection + "/points", request, Void.class); + } + + @Override + public List search(String collection, float[] queryVector, int topK, Map filter) { + Map body = new HashMap<>(); + body.put("vector", queryVector); + body.put("top", topK); + if (filter != null && !filter.isEmpty()) { + body.put("filter", Map.of("must", filter.entrySet().stream() + .map(e -> Map.of("key", e.getKey(), "match", Map.of("value", e.getValue()))) + .toList())); + } + + ResponseEntity resp = restTemplate.postForEntity( + baseUrl + "/collections/" + collection + "/points/search", + body, + JsonNode.class + ); + + JsonNode results = resp.getBody(); + List vectorResults = new ArrayList<>(); + for (JsonNode r : results.get("result")) { + vectorResults.add(new VectorResult( + r.get("id").asText(), + (float) r.get("score").asDouble(), + r.get("payload") + )); + } + return vectorResults; + } +} diff --git a/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/JiraProxyControllerTest.java b/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/JiraProxyControllerTest.java index 1abb65ef..641c5376 100644 --- a/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/JiraProxyControllerTest.java +++ b/integration-proxy/src/test/java/io/sentrius/sso/controllers/api/JiraProxyControllerTest.java @@ -172,7 +172,7 @@ void addCommentReturnsUnauthorizedWhenTokenIsInvalid() throws Exception { commentRequest.setText("Test comment"); // When - ResponseEntity result = jiraProxyController.addComment( + ResponseEntity result = jiraProxyController.addCommentToJiraIssue( invalidToken, "TEST-123", commentRequest, request, response ); diff --git a/llm-core/src/main/java/io/sentrius/sso/genai/model/EmbeddingRequest.java b/llm-core/src/main/java/io/sentrius/sso/genai/model/EmbeddingRequest.java new file mode 100644 index 00000000..36c5924a --- /dev/null +++ b/llm-core/src/main/java/io/sentrius/sso/genai/model/EmbeddingRequest.java @@ -0,0 +1,39 @@ +package io.sentrius.sso.genai.model; + +import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.sentrius.sso.genai.Message; +import io.sentrius.sso.genai.api.BaseGenerativeRequest; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +/** + *

+ * Inspired by LiLittleCat's ChatCopmletionRequestBody + *

+ * see: + * https://platform.openai.com/docs/api-reference/chat + * + * borrowed from LiLittleCat + * + * @since 2023/3/2 + */ +@Data +@SuperBuilder(toBuilder = true) +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class EmbeddingRequest extends BaseGenerativeRequest { + /** + * Required + *

+ * The messages to generate chat completions for, in the . + */ + @JsonProperty(value = "input") + private String input; +} diff --git a/llm-core/src/main/java/io/sentrius/sso/genai/model/EmbeddingResponse.java b/llm-core/src/main/java/io/sentrius/sso/genai/model/EmbeddingResponse.java new file mode 100644 index 00000000..332d15cf --- /dev/null +++ b/llm-core/src/main/java/io/sentrius/sso/genai/model/EmbeddingResponse.java @@ -0,0 +1,22 @@ +package io.sentrius.sso.genai.model; + +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class EmbeddingResponse { + private String model; + private int promptTokens; + private int totalTokens; + private List embedding; +} \ No newline at end of file diff --git a/llm-core/src/main/java/io/sentrius/sso/genai/model/endpoints/EmbeddingApiRequest.java b/llm-core/src/main/java/io/sentrius/sso/genai/model/endpoints/EmbeddingApiRequest.java new file mode 100644 index 00000000..34a1c75b --- /dev/null +++ b/llm-core/src/main/java/io/sentrius/sso/genai/model/endpoints/EmbeddingApiRequest.java @@ -0,0 +1,74 @@ +package io.sentrius.sso.genai.model.endpoints; + +import java.util.ArrayList; +import java.util.List; +import io.sentrius.sso.genai.Message; +import io.sentrius.sso.genai.model.ApiEndPointRequest; +import io.sentrius.sso.genai.model.EmbeddingRequest; +import io.sentrius.sso.genai.model.LLMRequest; +import io.sentrius.sso.genai.model.LLMResponse; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +/** + * Represents a request to the OpenAI Chat API endpoint. + * + * This class provides a convenient way to build a request to the OpenAI Chat API. It includes methods to set the input + * text, the model to use, and the parameters for the request, among others. Once the request is built, it can be sent + * using the {@link ChatApiEndpoint#send(EmbeddingApiRequest)} method. + * + * Example usage: + * + *

{@code
+ * ChatApiEndpointRequest request = new ChatApiEndpointRequest.builder().model("davinci").input("Hello, world!")
+ *         .build();
+ *
+ * ChatApiEndpoint endpoint = new ChatApiEndpoint(apiKey);
+ * ChatApiResponse response = endpoint.send(request);
+ * }
+ * + */ +@Data +@SuperBuilder +public class EmbeddingApiRequest extends ApiEndPointRequest { + + public static final String API_ENDPOINT = "https://api.openai.com/v1/embeddings"; + + + @Override + public String getEndpoint() { + return API_ENDPOINT; + } + + @Builder.Default + private String input = ""; + + @Builder.Default + private String model = "text-embedding-3-small"; + + /** + * Creates a new instance of the ChatApiEndpoint with the specified API key. + * + * This method is used to create a new instance of the ChatApiEndpoint with the specified API key. The API key is + * required to send requests to the OpenAI Chat API endpoint. If the API key is invalid or not provided, an + * IllegalArgumentException will be thrown. + * + * Example usage: + * + *
{@code
+     * ChatApiEndpoint endpoint = ChatApiEndpoint.create("my-api-key");
+     * }
+ * + * + * @return A new instance of the ChatApiEndpoint. + * + * @throws IllegalArgumentException + * If the API key is null or empty. + */ + @Override + public Object create() { + return EmbeddingRequest.builder().input(input).model(model).build(); + } + +} diff --git a/llm-dataplane/src/main/java/io/sentrius/sso/genai/GenerativeAPI.java b/llm-dataplane/src/main/java/io/sentrius/sso/genai/GenerativeAPI.java index 5f767633..baaa4166 100644 --- a/llm-dataplane/src/main/java/io/sentrius/sso/genai/GenerativeAPI.java +++ b/llm-dataplane/src/main/java/io/sentrius/sso/genai/GenerativeAPI.java @@ -6,6 +6,8 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import io.sentrius.sso.genai.model.EmbeddingRequest; +import io.sentrius.sso.genai.model.endpoints.EmbeddingApiRequest; import io.sentrius.sso.security.TokenProvider; import io.sentrius.sso.genai.model.ApiEndPointRequest; import io.sentrius.sso.integrations.exceptions.HttpException; @@ -68,9 +70,7 @@ String buildRequestBody(final ApiEndPointRequest request) { /** * ask for response message * - * @param apiRequest - * Api Request object - * + * @param apiRequest Api Request object * @return ChatCompletionResponseBody */ public String sample(final ApiEndPointRequest apiRequest) throws HttpException { @@ -111,4 +111,8 @@ public String sample(final ApiEndPointRequest apiRequest) throws HttpException { public T sample(final ApiEndPointRequest apiRequest, Class clazz) throws HttpException, JsonProcessingException { return (T) objectMapper.readValue(sample(apiRequest), clazz); } + + public String getEmbedding(final EmbeddingApiRequest request) throws HttpException { + return sample(request); + } } diff --git a/pom.xml b/pom.xml index 3cad2d93..9b722709 100644 --- a/pom.xml +++ b/pom.xml @@ -55,6 +55,7 @@ 1.0.0-M7 5.28.5 24.0.0 + 6.6.11.Final @@ -79,6 +80,11 @@ commons-codec ${commons-codec-version}
+ + org.apache.accumulo + accumulo-access + 1.0.0-beta + org.apache.commons commons-collections4 @@ -114,11 +120,17 @@ ${lombok.version} provided + + + org.hibernate.orm + hibernate-vector + ${hibernate-vector.version} + io.github.classgraph classgraph @@ -145,11 +157,12 @@ ${spring.boot.version} true + org.springframework.boot spring-boot-starter-test @@ -366,15 +379,13 @@ - + + org.springframework.ai + spring-ai-bom + ${spring-ai-version} + pom + import + diff --git a/python-agent/README.md b/python-agent/README.md index f8cfc432..084db8d8 100644 --- a/python-agent/README.md +++ b/python-agent/README.md @@ -64,7 +64,7 @@ context: | { "previousOperation": "", "nextOperation": "", - "terminalSummaryForLLM": "", + "summaryForLLM": "", "responseForUser": "" } ``` diff --git a/python-agent/agents/chat_helper/chat_helper_agent.py b/python-agent/agents/chat_helper/chat_helper_agent.py index 83f559f0..c53ba982 100644 --- a/python-agent/agents/chat_helper/chat_helper_agent.py +++ b/python-agent/agents/chat_helper/chat_helper_agent.py @@ -53,7 +53,7 @@ def _process_chat_request(self, task_data: Optional[Dict[str, Any]]) -> Dict[str return { "previousOperation": "initialization", "nextOperation": "waiting_for_user_input", - "terminalSummaryForLLM": "Chat helper agent initialized and ready", + "summaryForLLM": "Chat helper agent initialized and ready", "responseForUser": "Hello! I'm your chat helper agent. How can I assist you today?" } @@ -65,7 +65,7 @@ def _process_chat_request(self, task_data: Optional[Dict[str, Any]]) -> Dict[str return { "previousOperation": "user_message_received", "nextOperation": "generate_response", - "terminalSummaryForLLM": f"User asked: {user_message}", + "summaryForLLM": f"User asked: {user_message}", "responseForUser": f"I received your message: '{user_message}'. I'm a helpful chat assistant ready to help!" } diff --git a/python-agent/chat-helper.yaml b/python-agent/chat-helper.yaml index f8ecd04f..2117044a 100644 --- a/python-agent/chat-helper.yaml +++ b/python-agent/chat-helper.yaml @@ -5,6 +5,6 @@ context: | { "previousOperation": "", "nextOperation": "", - "terminalSummaryForLLM": "", + "summaryForLLM": "", "responseForUser": "" } \ No newline at end of file diff --git a/sentrius-chart-launcher/templates/configmap.yaml b/sentrius-chart-launcher/templates/configmap.yaml index 7464b54b..24256c60 100644 --- a/sentrius-chart-launcher/templates/configmap.yaml +++ b/sentrius-chart-launcher/templates/configmap.yaml @@ -13,7 +13,7 @@ data: { "previousOperation": "", "nextOperation": "", - "terminalSummaryForLLM": "", + "summaryForLLM": "", "responseForUser": "" } assessor-config.yaml: | diff --git a/sentrius-chart/templates/qdrant-deployment.yaml b/sentrius-chart/templates/qdrant-deployment.yaml new file mode 100644 index 00000000..286edaf4 --- /dev/null +++ b/sentrius-chart/templates/qdrant-deployment.yaml @@ -0,0 +1,48 @@ +# qdrant-deployment.yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: qdrant +spec: + replicas: 1 + selector: + matchLabels: + app: qdrant + template: + metadata: + labels: + app: qdrant + spec: + containers: + - name: qdrant + image: "{{ .Values.qdrant.image.repository }}:{{ .Values.qdrant.image.tag }}" + imagePullPolicy: "{{ .Values.qdrant.image.pullPolicy }}" + ports: + - containerPort: {{ .Values.qdrant.port }} + volumeMounts: + - name: qdrant-storage + mountPath: /qdrant/storage + + volumes: + - name: qdrant-storage + persistentVolumeClaim: + claimName: qdrant-pvc +--- +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-qdrant-from-integration-proxy +spec: + podSelector: + matchLabels: + app: qdrant + ingress: + - from: + - podSelector: + matchLabels: + app: integration-proxy + ports: + - protocol: TCP + port: {{ .Values.qdrant.port }} + policyTypes: + - Ingress diff --git a/sentrius-chart/templates/qdrant-pvc.yaml b/sentrius-chart/templates/qdrant-pvc.yaml new file mode 100644 index 00000000..b2ff8b8c --- /dev/null +++ b/sentrius-chart/templates/qdrant-pvc.yaml @@ -0,0 +1,12 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: qdrant-pvc + labels: + app: qdrant +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.qdrant.storageSize | default "10Gi" }} diff --git a/sentrius-chart/templates/qdrant-service.yaml b/sentrius-chart/templates/qdrant-service.yaml new file mode 100644 index 00000000..613a04c7 --- /dev/null +++ b/sentrius-chart/templates/qdrant-service.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: Service +metadata: + name: qdrant +spec: + type: ClusterIP + selector: + app: qdrant + ports: + - port: {{ .Values.qdrant.port }} + targetPort: {{ .Values.qdrant.port }} diff --git a/sentrius-chart/values.yaml b/sentrius-chart/values.yaml index 17dfbdbc..e0155b76 100644 --- a/sentrius-chart/values.yaml +++ b/sentrius-chart/values.yaml @@ -182,11 +182,21 @@ launcherservice: azure: service.beta.kubernetes.io/azure-load-balancer-internal: "true" +qdrant: + image: + repository: qdrant/qdrant + tag: v1.9.0 + pullPolicy: IfNotPresent + port: 6333 + storageSize: 10Gi + resources: {} + # PostgreSQL configuration postgres: image: - repository: postgres - tag: 15 + repository: pgvector/pgvector + tag: pg15 + pullPolicy: IfNotPresent port: 5432 storageSize: 10Gi env: