Merged
Changes from 10 commits
26 commits
bdc0d26
implemented Source unit for S3.
mustansir14 Jun 26, 2025
21d9334
use bucket as source unit
mustansir14 Jul 7, 2025
a25afff
remove code duplication, reuse from Chunks
mustansir14 Nov 19, 2025
ea21d02
remove unnecessary change
mustansir14 Nov 19, 2025
9915187
remove unused functions
mustansir14 Nov 19, 2025
c32e12c
revisit tests
mustansir14 Nov 19, 2025
5161090
revert unnecessary change
mustansir14 Nov 19, 2025
ef324d1
change SourceUnitKind to s3_bucket
mustansir14 Nov 20, 2025
84e0cda
handle nil objectCount inside scanBucket
mustansir14 Nov 20, 2025
b5a66d5
handle nil objectCount outside loop
mustansir14 Nov 20, 2025
966007f
add bucket to resume log
mustansir14 Nov 20, 2025
50e5a90
Merge branch 'main' into INS-104-Support-units-in-S3-source
amanfcp Nov 20, 2025
0faa70e
add bucket and role to error log, remove enumerating log
mustansir14 Nov 21, 2025
474172c
Merge branch 'INS-104-Support-units-in-S3-source' of mustansir:mustan…
mustansir14 Nov 21, 2025
10f91ff
implement sub unit resumption
mustansir14 Nov 24, 2025
6bfbc14
add comment to checkpointer for unit scans
mustansir14 Nov 24, 2025
1863659
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Nov 25, 2025
5ca4151
implement SourceUnitUnmarshaller on source with the new S3SourceUnit,…
mustansir14 Nov 26, 2025
45a133b
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Dec 1, 2025
549e6be
add role to SourceUnitID
mustansir14 Dec 2, 2025
b5cb928
Merge branch 'INS-104-Support-units-in-S3-source' of mustansir:mustan…
mustansir14 Dec 2, 2025
1cee9af
Revert "add role to SourceUnitID"
mustansir14 Dec 2, 2025
6f06776
add role to source unit ID, keep track of resumption using source uni…
mustansir14 Dec 3, 2025
85e681b
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Dec 3, 2025
53a91c8
rename bucket -> unitID in UnmarshalSourceUnit
mustansir14 Dec 4, 2025
3e8e6b9
Merge branch 'main' into INS-104-Support-units-in-S3-source
mustansir14 Dec 4, 2025
191 changes: 134 additions & 57 deletions pkg/sources/s3/s3.go
@@ -62,6 +62,7 @@ type Source struct {
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.Validator = (*Source)(nil)
var _ sources.SourceUnitEnumChunker = (*Source)(nil)

// Type returns the type of source
func (s *Source) Type() sourcespb.SourceType { return SourceType }
@@ -316,16 +317,7 @@ func (s *Source) scanBuckets(

bucketsToScanCount := len(bucketsToScan)
for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
s.metricsCollector.RecordBucketForRole(role)
bucket := bucketsToScan[bucketIdx]
ctx := context.WithValue(ctx, "bucket", bucket)

if common.IsDone(ctx) {
ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket")
return
}

ctx.Logger().V(3).Info("Scanning bucket")

s.SetProgressComplete(
bucketIdx,
@@ -334,53 +326,16 @@
s.Progress.EncodedResumeInfo,
)

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
ctx.Logger().Error(err, "could not get regional client for bucket")
continue
}

errorCount := sync.Map{}

input := &s3.ListObjectsV2Input{Bucket: &bucket}
var startAfter *string
if bucket == pos.bucket && pos.startAfter != "" {
input.StartAfter = &pos.startAfter
startAfter = &pos.startAfter
ctx.Logger().V(3).Info(
"Resuming bucket scan",
"start_after", pos.startAfter,
)
}

pageNumber := 1
paginator := s3.NewListObjectsV2Paginator(regionalClient, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
if role == "" {
ctx.Logger().Error(err, "could not list objects in bucket")
} else {
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
// and therefore not an error.
ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
}
break
}
pageMetadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: output,
}
processingState := processingState{
errorCount: &errorCount,
objectCount: &objectCount,
}
s.pageChunker(ctx, pageMetadata, processingState, chunksChan)

pageNumber++
}
s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter, &objectCount)
}

s.SetProgressComplete(
@@ -391,6 +346,75 @@
)
}

func (s *Source) scanBucket(
ctx context.Context,
client *s3.Client,
role string,
bucket string,
reporter sources.ChunkReporter,
startAfter *string,
objectCount *uint64,
) {
s.metricsCollector.RecordBucketForRole(role)

ctx = context.WithValue(ctx, "bucket", bucket)

if common.IsDone(ctx) {
ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket")
return
}

ctx.Logger().V(3).Info("Scanning bucket")

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
ctx.Logger().Error(err, "could not get regional client for bucket")
return
}

errorCount := sync.Map{}

input := &s3.ListObjectsV2Input{Bucket: &bucket}
if startAfter != nil {
input.StartAfter = startAfter
}

pageNumber := 1
paginator := s3.NewListObjectsV2Paginator(regionalClient, input)
if objectCount == nil {
var newObjectCount uint64
objectCount = &newObjectCount
}
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
if role == "" {
ctx.Logger().Error(err, "could not list objects in bucket")
} else {
// Our documentation blesses specifying a role to assume without specifying buckets to scan, which will
// often cause this to happen a lot (because in that case the scanner tries to scan every bucket in the
// account, but the role probably doesn't have access to all of them). This makes it expected behavior
// and therefore not an error.
ctx.Logger().V(3).Info("could not list objects in bucket", "err", err)
}
break
}
pageMetadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: output,
}
processingState := processingState{
errorCount: &errorCount,
objectCount: objectCount,
}
s.pageChunker(ctx, pageMetadata, processingState, reporter)

pageNumber++
}
}

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ ...sources.ChunkingTarget) error {
visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error {
@@ -429,14 +453,12 @@ func (s *Source) pageChunker(
ctx context.Context,
metadata pageMetadata,
state processingState,
chunksChan chan *sources.Chunk,
reporter sources.ChunkReporter,
) {
s.checkpointer.Reset() // Reset the checkpointer for each PAGE
ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)

for objIdx, obj := range metadata.page.Contents {
ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)

if common.IsDone(ctx) {
return
}
@@ -572,12 +594,11 @@ func (s *Source) pageChunker(
Verify: s.verify,
}

if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
if err := handlers.HandleFile(ctx, res.Body, chunkSkel, reporter); err != nil {
ctx.Logger().Error(err, "error handling file")
s.metricsCollector.RecordObjectError(metadata.bucket)
return nil
}

atomic.AddUint64(state.objectCount, 1)
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount)
nErr, ok = state.errorCount.Load(prefix)
@@ -587,17 +608,14 @@
if nErr.(int) > 0 {
state.errorCount.Store(prefix, 0)
}

// Update progress after successful processing.
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for scanned object")
}
s.metricsCollector.RecordObjectScanned(metadata.bucket, float64(*obj.Size))

return nil
})
}

_ = s.jobPool.Wait()
}

@@ -681,3 +699,62 @@ func (s *Source) visitRoles(
func makeS3Link(bucket, region, key string) string {
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucket, region, key)
}

type S3SourceUnit struct {
Contributor

You've defined this unit type, but you haven't modified the source to actually use it. The source type still embeds CommonSourceUnitUnmarshaller, so it will still unmarshal source units to CommonSourceUnit instead of your new type. You'll need to define custom unmarshalling logic. (The git source has an example of custom unmarshalling logic you can look at.)

Also, I recommend putting the unit struct and related code in a separate file, because we do that for several other sources, and I think it makes things more readable.

Contributor Author

Oh, thanks for pointing this out. I wasn't aware of this. I'll make the changes.

Bucket string
Role string
}

func (s S3SourceUnit) SourceUnitID() (string, sources.SourceUnitKind) {
// The ID is the bucket name, and the kind is "s3_bucket".
return s.Bucket, "s3_bucket"
}
Contributor

@mcastorina I forget - is it a problem if SourceUnitID can't be used to round-trip a unit? (In this case, we lose the Role field.)

Contributor Author

I'll wait for @mcastorina's answer before making changes here, but here's what the description comment says for SourceUnitID():

// SourceUnitID uniquely identifies a source unit. It does not need to
// be human readable or two-way, however, it should be canonical and
// stable across runs.

Bucket names are globally unique, so on that front we should be good.

Contributor

Oh, good catch. I guess the round-trip-ability happens in the source manager somewhere? (@mcastorina?)

Contributor

We take the full unit object and JSON marshal it, so the fields need to be public. Idk if I documented that anywhere though, but that's why a source needs to implement unmarshalling but not marshalling.

Contributor Author

All the fields are public, so we're good there. But based on our discussion in the thread below regarding having the role in resumption info, it seems like a good idea to have the role in the SourceUnitID as well. I'll add it.

Contributor Author

I added it, but realized it might not be best to do yet: it also affects sub-unit resumption, because resumption info is saved against the SourceUnitID and our current checkpointer only works with buckets, not roles. I have reverted the change and will wait for your responses before deciding whether roles should be part of the resumption info.
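
To make the unmarshalling discussion above concrete, here is a minimal sketch of a custom UnmarshalSourceUnit for this source. It assumes the sources.SourceUnitUnmarshaller interface accepts the JSON bytes produced by the source manager and returns a sources.SourceUnit, and that encoding/json is imported; the empty-bucket check is illustrative rather than taken from the PR.

// Minimal sketch of custom unit unmarshalling; the merged code may differ.
func (s *Source) UnmarshalSourceUnit(data []byte) (sources.SourceUnit, error) {
    var unit S3SourceUnit
    if err := json.Unmarshal(data, &unit); err != nil {
        return nil, fmt.Errorf("could not unmarshal S3 source unit: %w", err)
    }
    if unit.Bucket == "" {
        return nil, fmt.Errorf("S3 source unit has no bucket name")
    }
    return unit, nil
}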


func (s S3SourceUnit) Display() string {
return s.Bucket
}

var _ sources.SourceUnit = S3SourceUnit{}

// Enumerate implements the SourceUnitEnumerator interface. It visits each
// configured role and reports each S3 bucket as a source unit.
func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error {
for _, bucket := range buckets {
if common.IsDone(ctx) {
return ctx.Err()
}

ctx.Logger().V(5).Info("Enumerating bucket", "bucket", bucket)
Contributor

To me, "enumerating bucket" implies that the bucket itself is being ranged over, not that the bucket is being emitted from an enumeration of all buckets. I'd write this like "Found bucket" or something.

I'm not sure we even need a log message here - the bucket will show up in a list elsewhere, so logging it just duplicates that.

Contributor Author

You're right. I'll just remove the logs.


unit := S3SourceUnit{
Bucket: bucket,
Role: roleArn,
}

if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
}
return nil
}

return s.visitRoles(ctx, visitor)
}

func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {

s3unit, ok := unit.(S3SourceUnit)
if !ok {
return fmt.Errorf("expected S3SourceUnit, got %T", unit)
}
bucket := s3unit.Bucket

defaultClient, err := s.newClient(ctx, defaultAWSRegion, s3unit.Role)
if err != nil {
return fmt.Errorf("could not create s3 client: %w", err)
Contributor

I recommend adding the bucket and role to this error message so that if any end users see it they have a better chance of self-diagnosing the issue.

Contributor Author

Will do

}

s.scanBucket(ctx, defaultClient, s3unit.Role, bucket, reporter, nil, nil)
return nil
}
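
As a hedged illustration of the reviewer's suggestion above (and of the later "add bucket and role to error log" commit), the client-creation error could carry the bucket and role along these lines; the exact wording here is assumed, not taken from the merged code:

// Hypothetical error wording showing bucket and role context.
return fmt.Errorf("could not create s3 client for bucket %q (role %q): %w", bucket, s3unit.Role, err)
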
89 changes: 89 additions & 0 deletions pkg/sources/s3/s3_integration_test.go
@@ -18,6 +18,7 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest"
)

func TestSource_ChunksCount(t *testing.T) {
@@ -391,3 +392,91 @@ func TestSourceChunksResumptionMultipleBucketsIgnoredBucket(t *testing.T) {

assert.Equal(t, 103, count, "Should have processed all remaining data on resume")
}

func TestSource_Enumerate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

secret, err := common.GetTestSecret(ctx)
if err != nil {
t.Fatal(fmt.Errorf("failed to access secret: %v", err))
}

s3key := secret.MustGetField("AWS_S3_KEY")
s3secret := secret.MustGetField("AWS_S3_SECRET")

connection := &sourcespb.S3{
Credential: &sourcespb.S3_AccessKey{
AccessKey: &credentialspb.KeySecret{
Key: s3key,
Secret: s3secret,
},
},
Buckets: []string{"truffletestbucket"},
}

conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
}

s := Source{}
err = s.Init(ctx, "test enumerate", 0, 0, false, conn, 1)
assert.NoError(t, err)

reporter := sourcestest.TestReporter{}
err = s.Enumerate(ctx, &reporter)
assert.NoError(t, err)

assert.Equal(t, len(reporter.Units), 1)
assert.Equal(t, 0, len(reporter.UnitErrs), "Expected no errors during enumeration")

for _, unit := range reporter.Units {
id, _ := unit.SourceUnitID()
assert.NotEmpty(t, id, "Unit ID should not be empty")
}
}

func TestSource_ChunkUnit(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

secret, err := common.GetTestSecret(ctx)
if err != nil {
t.Fatal(fmt.Errorf("failed to access secret: %v", err))
}

s3key := secret.MustGetField("AWS_S3_KEY")
s3secret := secret.MustGetField("AWS_S3_SECRET")

connection := &sourcespb.S3{
Credential: &sourcespb.S3_AccessKey{
AccessKey: &credentialspb.KeySecret{
Key: s3key,
Secret: s3secret,
},
},
Buckets: []string{"truffletestbucket"},
}

conn, err := anypb.New(connection)
if err != nil {
t.Fatal(err)
}

s := Source{}
err = s.Init(ctx, "test enumerate", 0, 0, false, conn, 1)
assert.NoError(t, err)

reporter := sourcestest.TestReporter{}
err = s.Enumerate(ctx, &reporter)
assert.NoError(t, err)

for _, unit := range reporter.Units {
err = s.ChunkUnit(ctx, unit, &reporter)
assert.NoError(t, err, "Expected no error during ChunkUnit")
}

assert.Equal(t, 103, len(reporter.Chunks))
assert.Equal(t, 0, len(reporter.ChunkErrs))
}