Skip to content

Commit a8288aa

Browse files
committed
fix(ENGKNOW-3224): Add adaptive S3 multipart upload part size, allows upload of larger files (700GB vs 50GB)
1 parent da6df57 commit a8288aa

2 files changed

Lines changed: 340 additions & 9 deletions

File tree

drivers/src/main/java/org/gorpipe/s3/driver/S3MultipartOutputStream.java

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package org.gorpipe.s3.driver;
22

3-
import software.amazon.awssdk.core.async.AsyncRequestBody;
4-
import software.amazon.awssdk.services.s3.S3AsyncClient;
53
import software.amazon.awssdk.services.s3.model.*;
64

75
import java.io.IOException;
@@ -17,24 +15,42 @@
1715
public abstract class S3MultipartOutputStream extends OutputStream {
1816
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(S3MultipartOutputStream.class);
1917

20-
private static final int PART_SIZE = 5 * 1024 * 1024; // 5 MiB
21-
private static final int MAX_RETRIES = 3;
22-
private static final int MAX_CONCURRENT_UPLOADS = 4;
23-
private static final int RETRY_SLEEP_BASE_MS = 1000;
18+
public static final int MIN_S3_PART_SIZE = 5 * 1024 * 1024; // 5 MiB, min parts size allowed by the S3 dreiver
19+
public static final int INIT_PART_SIZE = Math.max(Integer.parseInt(System.getProperty(
20+
"gor.s3.multipart.initpartsize", String.valueOf(MIN_S3_PART_SIZE))), MIN_S3_PART_SIZE);
21+
public static final int MAX_PARTS = 10_000;
22+
23+
public static final int MAX_RETRIES = 3;
24+
public static final int MAX_CONCURRENT_UPLOADS = 4;
25+
public static final int RETRY_SLEEP_BASE_MS = 1000;
2426

2527
private final String bucket;
2628
private final String key;
29+
protected int currentPartSize = INIT_PART_SIZE;
30+
private ByteBuffer buffer;
2731
private final List<CompletedPart> completedParts = new ArrayList<>();
28-
private final ByteBuffer buffer = ByteBuffer.allocate(PART_SIZE);
2932
private final ExecutorService executor = Executors.newFixedThreadPool(MAX_CONCURRENT_UPLOADS);
3033

3134
private String uploadId;
3235
private int partNumber = 1;
3336
private boolean closed = false;
3437

35-
public S3MultipartOutputStream(String bucket, String key) throws IOException {
38+
public S3MultipartOutputStream(String bucket, String key) {
3639
this.bucket = bucket;
3740
this.key = key;
41+
buffer = ByteBuffer.allocate(currentPartSize);
42+
}
43+
44+
public int getUploadPartDone() {
45+
return completedParts.size();
46+
}
47+
48+
public int getUploadPartStarted() {
49+
return partNumber - 1;
50+
}
51+
52+
public int getCurrentPartSize() {
53+
return currentPartSize;
3854
}
3955

4056
abstract protected CreateMultipartUploadResponse sendCreateMultipartUploadRequest(CreateMultipartUploadRequest req) throws ExecutionException, InterruptedException;
@@ -83,7 +99,7 @@ private void uploadPartAsync() throws IOException {
8399
byte[] partData = new byte[buffer.limit()];
84100
buffer.get(partData);
85101
buffer.clear();
86-
102+
adaptPartSize();
87103
int currentPart = partNumber++;
88104

89105
if (uploadId == null) {
@@ -176,4 +192,15 @@ private void abortMultipartUpload() {
176192
logger.warn("Failed to abort multipart upload (ignoring exception)", ignored);
177193
}
178194
}
195+
196+
/*
197+
Simply quaddrouble the buffer every 2500 parts, given 5MB start buffer, allows for approx 850GB file, with
198+
the first segment (2500 parts) allowing 12GB files, and max buffer approx 300Mb.
199+
*/
200+
private void adaptPartSize() {
201+
if (partNumber < MAX_PARTS && partNumber % 2500 == 0) {
202+
currentPartSize *= 4;
203+
buffer = ByteBuffer.allocate(currentPartSize);
204+
}
205+
}
179206
}
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
package org.gorpipe.s3.driver;
2+
3+
import org.junit.Test;
4+
import software.amazon.awssdk.services.s3.model.*;
5+
6+
import java.io.ByteArrayOutputStream;
7+
import java.io.IOException;
8+
import java.lang.reflect.Field;
9+
import java.lang.reflect.Method;
10+
import java.nio.ByteBuffer;
11+
import java.security.MessageDigest;
12+
import java.util.ArrayList;
13+
import java.util.List;
14+
import java.util.Random;
15+
16+
import static org.junit.Assert.*;
17+
18+
public class UTestS3MultipartOutputStream {
19+
20+
/**
21+
* A test-only concrete implementation that captures uploaded parts in memory.
22+
*/
23+
static class TestMultipartOutputStream extends S3MultipartOutputStream {
24+
final List<byte[]> uploadedParts = new ArrayList<>();
25+
final ByteArrayOutputStream allBytes = new ByteArrayOutputStream();
26+
boolean uploadAborted = false;
27+
boolean uploadCompleted = false;
28+
29+
TestMultipartOutputStream(String bucket, String key) {
30+
super(bucket, key);
31+
}
32+
33+
TestMultipartOutputStream(String bucket, String key, int partSize) {
34+
super(bucket, key);
35+
this.currentPartSize = partSize;
36+
try {
37+
Field bufferField = S3MultipartOutputStream.class.getDeclaredField("buffer");
38+
bufferField.setAccessible(true);
39+
bufferField.set(this, ByteBuffer.allocate(partSize));
40+
} catch (ReflectiveOperationException e) {
41+
throw new RuntimeException(e);
42+
}
43+
}
44+
45+
@Override
46+
protected CreateMultipartUploadResponse sendCreateMultipartUploadRequest(CreateMultipartUploadRequest req) {
47+
return CreateMultipartUploadResponse.builder()
48+
.uploadId("test-upload-id")
49+
.build();
50+
}
51+
52+
@Override
53+
protected UploadPartResponse sendUploadPartRequest(UploadPartRequest req, byte[] data) throws Exception {
54+
uploadedParts.add(data.clone());
55+
allBytes.write(data, 0, data.length);
56+
return UploadPartResponse.builder()
57+
.eTag("etag-" + req.partNumber())
58+
.build();
59+
}
60+
61+
@Override
62+
protected CompleteMultipartUploadResponse sendCompleteMultipartUploadRequest(CompleteMultipartUploadRequest req) {
63+
uploadCompleted = true;
64+
return CompleteMultipartUploadResponse.builder().build();
65+
}
66+
67+
@Override
68+
protected AbortMultipartUploadResponse sendAbortMultipartUploadRequest(AbortMultipartUploadRequest req) {
69+
uploadAborted = true;
70+
return AbortMultipartUploadResponse.builder().build();
71+
}
72+
}
73+
74+
/**
75+
* Subclass that fails uploads to test abort behavior.
76+
*/
77+
static class FailingMultipartOutputStream extends TestMultipartOutputStream {
78+
FailingMultipartOutputStream(String bucket, String key) throws IOException {
79+
super(bucket, key);
80+
}
81+
82+
@Override
83+
protected UploadPartResponse sendUploadPartRequest(UploadPartRequest req, byte[] data) throws Exception {
84+
throw new RuntimeException("Simulated upload failure");
85+
}
86+
}
87+
88+
@Test
89+
public void testSmallWrite() throws IOException {
90+
try (TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key")) {
91+
byte[] data = new byte[100];
92+
for (int i = 0; i < data.length; i++) data[i] = (byte) (i % 127);
93+
out.write(data);
94+
assertEquals(0, out.getUploadPartStarted());
95+
assertEquals(0, out.getUploadPartDone());
96+
}
97+
}
98+
99+
@Test
100+
public void testWriteExactlyOnePartSize() throws IOException {
101+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
102+
byte[] data = new byte[S3MultipartOutputStream.INIT_PART_SIZE];
103+
for (int i = 0; i < data.length; i++) data[i] = (byte) (i % 256);
104+
out.write(data);
105+
out.close();
106+
107+
assertEquals(1, out.getUploadPartDone());
108+
assertEquals(1, out.getUploadPartStarted());
109+
assertEquals(S3MultipartOutputStream.INIT_PART_SIZE, out.uploadedParts.get(0).length);
110+
assertTrue(out.uploadCompleted);
111+
}
112+
113+
@Test
114+
public void testWriteMultipleParts() throws IOException {
115+
int partSize = S3MultipartOutputStream.INIT_PART_SIZE;
116+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
117+
118+
// Write 2.5 parts worth of data
119+
int totalSize = partSize * 2 + partSize / 2;
120+
byte[] data = new byte[totalSize];
121+
for (int i = 0; i < data.length; i++) data[i] = (byte) (i % 256);
122+
out.write(data);
123+
out.close();
124+
125+
// Should produce 3 parts: 2 full + 1 partial
126+
assertEquals(3, out.getUploadPartDone());
127+
assertEquals(3, out.getUploadPartStarted());
128+
assertEquals(partSize, out.uploadedParts.get(0).length);
129+
assertEquals(partSize, out.uploadedParts.get(1).length);
130+
assertEquals(partSize / 2, out.uploadedParts.get(2).length);
131+
assertTrue(out.uploadCompleted);
132+
}
133+
134+
@Test
135+
public void testDataIntegrity() throws IOException {
136+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
137+
138+
int totalSize = S3MultipartOutputStream.INIT_PART_SIZE * 3 + 1000;
139+
byte[] data = new byte[totalSize];
140+
for (int i = 0; i < data.length; i++) data[i] = (byte) (i % 256);
141+
out.write(data);
142+
out.close();
143+
144+
byte[] result = out.allBytes.toByteArray();
145+
assertEquals(totalSize, result.length);
146+
assertArrayEquals(data, result);
147+
}
148+
149+
@Test
150+
public void testSingleByteWrite() throws IOException {
151+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
152+
out.write(42);
153+
out.close();
154+
155+
assertEquals(1, out.getUploadPartDone());
156+
assertEquals(1, out.uploadedParts.get(0).length);
157+
assertEquals(42, out.uploadedParts.get(0)[0]);
158+
assertTrue(out.uploadCompleted);
159+
}
160+
161+
@Test
162+
public void testCloseIdempotent() throws IOException {
163+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
164+
out.write(new byte[]{1, 2, 3});
165+
out.close();
166+
int partsAfterFirstClose = out.getUploadPartDone();
167+
out.close(); // should be no-op
168+
assertEquals(partsAfterFirstClose, out.getUploadPartDone());
169+
}
170+
171+
@Test
172+
public void testFailedUploadAbortsMultipart() throws IOException {
173+
FailingMultipartOutputStream out = new FailingMultipartOutputStream("bucket", "key");
174+
byte[] data = new byte[S3MultipartOutputStream.INIT_PART_SIZE + 1]; // enough to trigger a part upload
175+
176+
try {
177+
out.write(data);
178+
fail("Expected IOException");
179+
} catch (IOException e) {
180+
assertTrue(e.getMessage().contains("Failed to upload part"));
181+
}
182+
assertTrue(out.uploadAborted);
183+
}
184+
185+
@Test
186+
public void testInitialPartSize() throws IOException {
187+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
188+
assertEquals(S3MultipartOutputStream.INIT_PART_SIZE, out.getCurrentPartSize());
189+
assertTrue(out.getCurrentPartSize() >= S3MultipartOutputStream.MIN_S3_PART_SIZE);
190+
out.close();
191+
}
192+
193+
@Test
194+
public void testAdaptPartSize() throws Exception {
195+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
196+
197+
Method adaptMethod = S3MultipartOutputStream.class.getDeclaredMethod("adaptPartSize");
198+
adaptMethod.setAccessible(true);
199+
Field partNumberField = S3MultipartOutputStream.class.getDeclaredField("partNumber");
200+
partNumberField.setAccessible(true);
201+
202+
int initialPartSize = out.getCurrentPartSize();
203+
204+
// Set partNumber to 2500 (first trigger point)
205+
partNumberField.set(out, 2500);
206+
adaptMethod.invoke(out);
207+
assertEquals("Part size should quadruple at 2500 parts", initialPartSize * 4, out.getCurrentPartSize());
208+
209+
// Set partNumber to 5000 (second trigger point)
210+
partNumberField.set(out, 5000);
211+
adaptMethod.invoke(out);
212+
assertEquals("Part size should quadruple again at 5000 parts", initialPartSize * 16, out.getCurrentPartSize());
213+
214+
// Set partNumber to 7500 (third trigger point)
215+
partNumberField.set(out, 7500);
216+
adaptMethod.invoke(out);
217+
assertEquals("Part size should quadruple again at 7500 parts", initialPartSize * 64, out.getCurrentPartSize());
218+
219+
out.close();
220+
}
221+
222+
@Test
223+
public void testAdaptPartSizeDoesNotTriggerBetweenBoundaries() throws Exception {
224+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
225+
226+
Method adaptMethod = S3MultipartOutputStream.class.getDeclaredMethod("adaptPartSize");
227+
adaptMethod.setAccessible(true);
228+
Field partNumberField = S3MultipartOutputStream.class.getDeclaredField("partNumber");
229+
partNumberField.setAccessible(true);
230+
231+
int initialPartSize = out.getCurrentPartSize();
232+
233+
// Part numbers that should NOT trigger adaptation
234+
for (int pn : new int[]{1, 100, 1000, 2499, 2501, 3000}) {
235+
partNumberField.set(out, pn);
236+
adaptMethod.invoke(out);
237+
assertEquals("Part size should not change at partNumber " + pn, initialPartSize, out.getCurrentPartSize());
238+
}
239+
240+
out.close();
241+
}
242+
243+
@Test
244+
public void testAdaptPartSizeDoesNotTriggerAtOrAboveMaxParts() throws Exception {
245+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key");
246+
247+
Method adaptMethod = S3MultipartOutputStream.class.getDeclaredMethod("adaptPartSize");
248+
adaptMethod.setAccessible(true);
249+
Field partNumberField = S3MultipartOutputStream.class.getDeclaredField("partNumber");
250+
partNumberField.setAccessible(true);
251+
252+
int initialPartSize = out.getCurrentPartSize();
253+
254+
// At MAX_PARTS (10000), even though 10000 % 2500 == 0, should not trigger because partNumber >= MAX_PARTS
255+
partNumberField.set(out, S3MultipartOutputStream.MAX_PARTS);
256+
adaptMethod.invoke(out);
257+
assertEquals("Part size should not change at MAX_PARTS", initialPartSize, out.getCurrentPartSize());
258+
259+
out.close();
260+
}
261+
262+
@Test
263+
public void testLargeStreamWithAdaptivePartSizeMd5Integrity() throws Exception {
264+
int smallPartSize = 256; // 256 bytes
265+
266+
TestMultipartOutputStream out = new TestMultipartOutputStream("bucket", "key", smallPartSize);
267+
268+
// With 256-byte parts and adaptive sizing:
269+
// Parts 1-2500: 256 bytes each (adapt fires at partNumber 2500, data already extracted)
270+
// Parts 2501-5000: 1KB each (after first quadruple)
271+
// Parts 5001-6000: 4KB each (after second quadruple)
272+
// Total = 2500*256 + 2500*1024 + 1000*4096 = 7,296,000 bytes
273+
int totalBytes = 2500 * smallPartSize + 2500 * (smallPartSize * 4) + 1000 * (smallPartSize * 16);
274+
275+
Random rng = new Random(42); // fixed seed for reproducibility
276+
MessageDigest inputMd5 = MessageDigest.getInstance("MD5");
277+
278+
// Stream random data in variable-sized chunks
279+
int written = 0;
280+
byte[] chunk = new byte[8192];
281+
while (written < totalBytes) {
282+
int toWrite = Math.min(chunk.length, totalBytes - written);
283+
rng.nextBytes(chunk);
284+
inputMd5.update(chunk, 0, toWrite);
285+
out.write(chunk, 0, toWrite);
286+
written += toWrite;
287+
}
288+
out.close();
289+
290+
assertEquals(6000, out.getUploadPartDone());
291+
assertTrue(out.uploadCompleted);
292+
293+
// Verify adaptive sizing kicked in
294+
assertEquals(smallPartSize * 16, out.getCurrentPartSize());
295+
296+
// Compare MD5 of input vs output
297+
byte[] inputDigest = inputMd5.digest();
298+
MessageDigest outputMd5 = MessageDigest.getInstance("MD5");
299+
outputMd5.update(out.allBytes.toByteArray());
300+
byte[] outputDigest = outputMd5.digest();
301+
302+
assertArrayEquals("MD5 of input and output must match", inputDigest, outputDigest);
303+
}
304+
}

0 commit comments

Comments
 (0)