diff --git a/src/main/java/io/moderne/jsonrpc/formatter/JsonMessageFormatter.java b/src/main/java/io/moderne/jsonrpc/formatter/JsonMessageFormatter.java index bbfd205..7d50c9d 100644 --- a/src/main/java/io/moderne/jsonrpc/formatter/JsonMessageFormatter.java +++ b/src/main/java/io/moderne/jsonrpc/formatter/JsonMessageFormatter.java @@ -17,11 +17,14 @@ import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.databind.cfg.ConstructorDetector; import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.databind.util.TokenBuffer; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.module.paramnames.ParameterNamesModule; import io.moderne.jsonrpc.JsonRpcError; @@ -33,10 +36,15 @@ import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Type; -import java.util.Map; public class JsonMessageFormatter implements MessageFormatter { private final ObjectMapper mapper; + private final ClassValue writerCache = new ClassValue() { + @Override + protected ObjectWriter computeValue(Class type) { + return mapper.writerFor(type); + } + }; public JsonMessageFormatter() { this(JsonMapper.builder() @@ -77,27 +85,90 @@ public JsonMessageFormatter(ObjectMapper mapper) { @Override public JsonRpcMessage deserialize(InputStream in) throws IOException { - Map payload = mapper.readValue(in, new TypeReference>() { - }); - if (payload.containsKey("method")) { - return mapper.convertValue(payload, JsonRpcRequest.class); - } else if (payload.containsKey("error")) { - return mapper.convertValue(payload, JsonRpcError.class); + JsonParser parser = mapper.getFactory().createParser(in); + + Object id = null; + String method = null; + TokenBuffer params = null; + TokenBuffer errorBuffer = null; + Object result = null; + + if (parser.nextToken() != JsonToken.START_OBJECT) { + return JsonRpcError.invalidRequest(null, "Expected JSON object"); } - Object id = payload.get("id"); + + while (parser.nextToken() != JsonToken.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case "jsonrpc": + parser.skipChildren(); + break; + case "id": + id = normalizeId(parser.readValueAs(Object.class)); + break; + case "method": + method = parser.getValueAsString(); + break; + case "params": + params = captureValue(parser); + break; + case "error": + errorBuffer = captureValue(parser); + break; + case "result": + JsonToken token = parser.currentToken(); + if (token == JsonToken.START_OBJECT || token == JsonToken.START_ARRAY) { + result = captureValue(parser); + } else { + result = parser.readValueAs(Object.class); + } + break; + default: + parser.skipChildren(); + break; + } + } + + if (method != null) { + return new JsonRpcRequest(id, method, params); + } else if (errorBuffer != null) { + JsonRpcError.Detail detail = convertValue(errorBuffer, JsonRpcError.Detail.class); + return new JsonRpcError(id, detail); + } + return JsonRpcSuccess.fromPayload(id, result, this); + } + + private TokenBuffer captureValue(JsonParser parser) throws IOException { + TokenBuffer buffer = new TokenBuffer(parser); + buffer.copyCurrentStructure(parser); + return buffer; + } + + private Object normalizeId(Object id) { if (id instanceof Number) { - id = ((Number) id).intValue(); + return ((Number) id).intValue(); } - return JsonRpcSuccess.fromPayload(id, payload.get("result"), this); + return id; } @Override public void serialize(JsonRpcMessage message, OutputStream out) throws IOException { - mapper.writeValue(out, message); + writerCache.get(message.getClass()).writeValue(out, message); } @Override public T convertValue(Object value, Type type) { + if (value instanceof TokenBuffer) { + try { + JsonParser bufferParser = ((TokenBuffer) value).asParser(); + bufferParser.nextToken(); + return mapper.readValue(bufferParser, mapper.getTypeFactory().constructType(type)); + } catch (IOException e) { + throw new RuntimeException("Failed to convert TokenBuffer", e); + } + } return mapper.convertValue(value, mapper.getTypeFactory().constructType(type)); } } diff --git a/src/main/java/io/moderne/jsonrpc/handler/HeaderDelimitedMessageHandler.java b/src/main/java/io/moderne/jsonrpc/handler/HeaderDelimitedMessageHandler.java index d91747b..046b256 100644 --- a/src/main/java/io/moderne/jsonrpc/handler/HeaderDelimitedMessageHandler.java +++ b/src/main/java/io/moderne/jsonrpc/handler/HeaderDelimitedMessageHandler.java @@ -18,6 +18,7 @@ import io.moderne.jsonrpc.JsonRpcError; import io.moderne.jsonrpc.JsonRpcMessage; import io.moderne.jsonrpc.formatter.MessageFormatter; +import io.moderne.jsonrpc.internal.LimitedInputStream; import lombok.RequiredArgsConstructor; import org.jspecify.annotations.Nullable; @@ -35,6 +36,8 @@ @RequiredArgsConstructor public class HeaderDelimitedMessageHandler implements MessageHandler { private static final Pattern CONTENT_LENGTH = Pattern.compile("Content-Length: (\\d+)"); + private static final ThreadLocal SEND_BUFFER = + ThreadLocal.withInitial(() -> new ByteArrayOutputStream(8192)); private final InputStream inputStream; private final OutputStream outputStream; @@ -82,20 +85,11 @@ public JsonRpcMessage receive(MessageFormatter formatter) { } } - byte[] content = new byte[Integer.parseInt(contentLengthMatcher.group(1))]; - for (int totalRead = 0; totalRead < content.length; ) { - int bytesRead = inputStream.read(content, totalRead, content.length - totalRead); - if (bytesRead == -1) { - // Stream ended unexpectedly before reading full content - return JsonRpcError.invalidRequest(null, - "Content length mismatch. Expected " + content.length + - " but received " + totalRead); - } - totalRead += bytesRead; - } - - ByteArrayInputStream bis = new ByteArrayInputStream(content); - return effectiveFormatter.deserialize(bis); + int length = Integer.parseInt(contentLengthMatcher.group(1)); + LimitedInputStream limited = new LimitedInputStream(inputStream, length); + JsonRpcMessage msg = effectiveFormatter.deserialize(limited); + limited.skipRemaining(); + return msg; } catch (IOException e) { return JsonRpcError.invalidRequest(null, e.getMessage()); } @@ -119,16 +113,16 @@ private String readLineFromInputStream() throws IOException { public void send(JsonRpcMessage msg, MessageFormatter formatter) { MessageFormatter effectiveFormatter = this.formatter != null ? this.formatter : formatter; try { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - effectiveFormatter.serialize(msg, bos); - byte[] content = bos.toByteArray(); - outputStream.write(("Content-Length: " + content.length + "\r\n").getBytes()); + ByteArrayOutputStream buffer = SEND_BUFFER.get(); + buffer.reset(); + effectiveFormatter.serialize(msg, buffer); + outputStream.write(("Content-Length: " + buffer.size() + "\r\n").getBytes()); if (effectiveFormatter.getEncoding() != StandardCharsets.UTF_8) { outputStream.write(("Content-Type: application/vscode-jsonrpc;charset=" + effectiveFormatter.getEncoding().name() + "\r\n").getBytes()); } outputStream.write('\r'); outputStream.write('\n'); - outputStream.write(content); + buffer.writeTo(outputStream); outputStream.flush(); } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/src/main/java/io/moderne/jsonrpc/internal/LimitedInputStream.java b/src/main/java/io/moderne/jsonrpc/internal/LimitedInputStream.java new file mode 100644 index 0000000..b93b393 --- /dev/null +++ b/src/main/java/io/moderne/jsonrpc/internal/LimitedInputStream.java @@ -0,0 +1,94 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.moderne.jsonrpc.internal; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * An InputStream wrapper that limits the number of bytes that can be read + * from the underlying stream. Once the limit is reached, further reads + * return -1 (EOF). + */ +public class LimitedInputStream extends FilterInputStream { + private long remaining; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + this.remaining = limit; + } + + @Override + public int read() throws IOException { + if (remaining <= 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + remaining--; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (remaining <= 0) { + return -1; + } + int toRead = (int) Math.min(len, remaining); + int result = in.read(b, off, toRead); + if (result > 0) { + remaining -= result; + } + return result; + } + + @Override + public long skip(long n) throws IOException { + long toSkip = Math.min(n, remaining); + long skipped = in.skip(toSkip); + remaining -= skipped; + return skipped; + } + + @Override + public int available() throws IOException { + return (int) Math.min(in.available(), remaining); + } + + @Override + public void close() { + // Do not close the underlying stream - we need it for subsequent messages + } + + /** + * Skips any remaining bytes up to the limit. + * Call this after reading to ensure the underlying stream is + * positioned correctly for subsequent reads. + */ + public void skipRemaining() throws IOException { + while (remaining > 0) { + long skipped = skip(remaining); + if (skipped <= 0) { + // skip() didn't work, try reading + if (read() == -1) { + break; + } + } + } + } +}