diff --git a/CHANGELOG.md b/CHANGELOG.md index a4ae615..2edec06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `data.Setter.AddDataToContext` no longer takes a variadic. Signature is now `AddDataToContext(ctx, data map[string]any) (context.Context, error)`. Callers with multiple sources must merge them first. +- **BREAKING**: `Compiler.Compile` and `Loader.GetReader` now take a + `context.Context` as the first argument. `NewExecutableUnit` gains a leading + `ctx` too. The Extism compiler's `WithContext` option and the FromHTTP + loader's `GetReaderWithContext` helper are removed (now redundant). ### Deprecated - The twelve legacy top-level constructors (`FromRisorFile`, diff --git a/engines/extism/compiler/compiler.go b/engines/extism/compiler/compiler.go index e46d10d..7f5c1a1 100644 --- a/engines/extism/compiler/compiler.go +++ b/engines/extism/compiler/compiler.go @@ -19,7 +19,6 @@ type compileFunc func(ctx context.Context, wasmBytes []byte, opts *compile.Setti // Compiler implements the script.Compiler interface for WASM modules type Compiler struct { entryPointName string - ctx context.Context options *compile.Settings logHandler slog.Handler logger *slog.Logger @@ -56,24 +55,25 @@ func (c *Compiler) String() string { return "extism.Compiler" } -// Compile implements script.Compiler -func (c *Compiler) Compile(scriptReader io.ReadCloser) (script.ExecutableContent, error) { +// Compile implements script.Compiler. Cancelling ctx halts the WASM +// compilation and the entry-point probe. +func (c *Compiler) Compile(ctx context.Context, scriptReader io.ReadCloser) (script.ExecutableContent, error) { logger := c.logger.WithGroup("compile") if scriptReader == nil { return nil, ErrContentNil } + defer func() { + if err := scriptReader.Close(); err != nil { + logger.Warn("failed to close script reader", "error", err) + } + }() scriptBytes, err := io.ReadAll(scriptReader) if err != nil { return nil, fmt.Errorf("failed to read script: %w", err) } - err = scriptReader.Close() - if err != nil { - return nil, fmt.Errorf("failed to close reader: %w", err) - } - if len(scriptBytes) == 0 { logger.Error("Compile called with empty script") return nil, ErrContentNil @@ -82,7 +82,7 @@ func (c *Compiler) Compile(scriptReader io.ReadCloser) (script.ExecutableContent logger.Debug("Starting WASM compilation", "scriptLength", len(scriptBytes)) // Compile the WASM module using the configured compile function (defaults to compile.CompileBytes) - plugin, err := c.compileFn(c.ctx, scriptBytes, c.options) + plugin, err := c.compileFn(ctx, scriptBytes, c.options) if err != nil { return nil, fmt.Errorf("%w: %w", ErrValidationFailed, err) } @@ -92,12 +92,14 @@ func (c *Compiler) Compile(scriptReader io.ReadCloser) (script.ExecutableContent } // Create a temporary instance to verify the entry point exists - instance, err := plugin.Instance(c.ctx, extismSDK.PluginInstanceConfig{}) + instance, err := plugin.Instance(ctx, extismSDK.PluginInstanceConfig{}) if err != nil { return nil, fmt.Errorf("%w: failed to create test instance: %w", ErrValidationFailed, err) } defer func() { - if err := instance.Close(c.ctx); err != nil { + // Use a cancel-immune ctx for cleanup so a cancelled compile ctx + // doesn't abort the plugin instance Close and leak resources. + if err := instance.Close(context.WithoutCancel(ctx)); err != nil { logger.Warn("Failed to close Extism plugin instance in compiler", "error", err) } }() diff --git a/engines/extism/compiler/compiler_test.go b/engines/extism/compiler/compiler_test.go index 3c10730..ab7be0a 100644 --- a/engines/extism/compiler/compiler_test.go +++ b/engines/extism/compiler/compiler_test.go @@ -120,7 +120,7 @@ func TestCompiler_Compile(t *testing.T) { reader := newMockScriptReaderCloser(wasmBytes) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) @@ -164,7 +164,7 @@ func TestCompiler_Compile(t *testing.T) { reader := newMockScriptReaderCloser(wasmBytes) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) @@ -208,7 +208,7 @@ func TestCompiler_Compile(t *testing.T) { reader := newMockScriptReaderCloser(wasmBytes) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) @@ -231,7 +231,7 @@ func TestCompiler_Compile(t *testing.T) { require.NoError(t, err) require.NotNil(t, comp) - execContent, err := comp.Compile(nil) + execContent, err := comp.Compile(t.Context(), nil) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrContentNil) @@ -248,7 +248,7 @@ func TestCompiler_Compile(t *testing.T) { reader := newMockScriptReaderCloser([]byte{}) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrContentNil) @@ -267,7 +267,7 @@ func TestCompiler_Compile(t *testing.T) { reader := newMockScriptReaderCloser([]byte("not-wasm")) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrValidationFailed) @@ -287,7 +287,7 @@ func TestCompiler_Compile(t *testing.T) { reader := newMockScriptReaderCloser(wasmBytes) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrValidationFailed) @@ -309,28 +309,29 @@ func TestCompiler_Compile_Branches(t *testing.T) { comp := createTestCompiler(t, "main") reader := &errReader{readErr: errors.New("read kaboom")} - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorContains(t, err, "failed to read script") require.ErrorContains(t, err, "read kaboom") }) - t.Run("scriptReader.Close fails", func(t *testing.T) { + t.Run("scriptReader.Close error is silenced", func(t *testing.T) { t.Parallel() comp := createTestCompiler(t, "main") - // Read succeeds (returns EOF on first Read with empty buffer is fine for - // io.ReadAll, but we want non-empty bytes so Close error is the only failure). + // Close errors are logged via the deferred cleanup, not returned. The + // "anything" bytes still fail validation; we assert that error + // propagates (not the close error). reader := &readCloseErr{ buf: bytes.NewReader([]byte("anything")), closeErr: errors.New("close kaboom"), } - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) - require.ErrorContains(t, err, "failed to close reader") - require.ErrorContains(t, err, "close kaboom") + require.ErrorIs(t, err, ErrValidationFailed) + require.NotContains(t, err.Error(), "close kaboom") }) t.Run("compileFn returns nil plugin without error", func(t *testing.T) { @@ -347,7 +348,7 @@ func TestCompiler_Compile_Branches(t *testing.T) { reader := newMockScriptReaderCloser([]byte("any-bytes")) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrBytecodeNil) @@ -373,7 +374,7 @@ func TestCompiler_Compile_Branches(t *testing.T) { reader := newMockScriptReaderCloser([]byte("any-bytes")) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrValidationFailed) @@ -405,7 +406,7 @@ func TestCompiler_Compile_Branches(t *testing.T) { reader := newMockScriptReaderCloser([]byte("any-bytes")) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err, "instance.Close error must not propagate") require.NotNil(t, execContent) plugin.AssertExpectations(t) @@ -435,7 +436,7 @@ func TestCompiler_Compile_Branches(t *testing.T) { reader := newMockScriptReaderCloser([]byte("any-bytes")) reader.On("Close").Return(nil) - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err) require.Nil(t, execContent) require.ErrorIs(t, err, ErrValidationFailed) diff --git a/engines/extism/compiler/options.go b/engines/extism/compiler/options.go index 9cd0b3d..3476aad 100644 --- a/engines/extism/compiler/options.go +++ b/engines/extism/compiler/options.go @@ -1,7 +1,6 @@ package compiler import ( - "context" "fmt" "log/slog" @@ -96,17 +95,6 @@ func WithHostFunctions(funcs []extismSDK.HostFunction) FunctionalOption { } } -// WithContext creates an option to set a custom context for the Extism compiler. -func WithContext(ctx context.Context) FunctionalOption { - return func(c *Compiler) error { - if ctx == nil { - return fmt.Errorf("context cannot be nil") - } - c.ctx = ctx - return nil - } -} - // applyDefaults sets the default values for a compiler. func (c *Compiler) applyDefaults() { // Set default entry point @@ -132,11 +120,6 @@ func (c *Compiler) applyDefaults() { // Default WASI to true (EnableWASI is a bool so we don't need to check if it's nil) c.options.EnableWASI = true - // Default context - if c.ctx == nil { - c.ctx = context.Background() - } - // Default compile function (test seam); production path is compile.CompileBytes if c.compileFn == nil { c.compileFn = compile.CompileBytes @@ -167,10 +150,5 @@ func (c *Compiler) validate() error { return fmt.Errorf("runtime config cannot be nil") } - // Context cannot be nil - if c.ctx == nil { - return fmt.Errorf("context cannot be nil") - } - return nil } diff --git a/engines/extism/compiler/options_test.go b/engines/extism/compiler/options_test.go index 71955aa..09bffa1 100644 --- a/engines/extism/compiler/options_test.go +++ b/engines/extism/compiler/options_test.go @@ -8,7 +8,6 @@ import ( extismSDK "github.com/extism/go-sdk" "github.com/robbyt/go-polyscript/engines/extism/compiler/internal/compile" - "github.com/robbyt/go-polyscript/platform/constants" "github.com/stretchr/testify/require" "github.com/tetratelabs/wazero" ) @@ -313,52 +312,6 @@ func TestCompilerOptions_Options(t *testing.T) { require.Equal(t, hostFuncs, c.options.HostFunctions) }) }) - - // WithContext tests - t.Run("WithContext", func(t *testing.T) { - // Success cases - t.Run("valid context", func(t *testing.T) { - customCtx := context.WithValue( - t.Context(), - constants.EvalData, - "test-value", - ) - - c := &Compiler{} - c.applyDefaults() - opt := WithContext(customCtx) - err := opt(c) - - require.NoError(t, err) - require.Equal(t, customCtx, c.ctx) - require.Equal(t, "test-value", c.ctx.Value(constants.EvalData)) - }) - - t.Run("background context", func(t *testing.T) { - ctx := t.Context() - - c := &Compiler{} - c.applyDefaults() - opt := WithContext(ctx) - err := opt(c) - - require.NoError(t, err) - require.Equal(t, ctx, c.ctx) - }) - - // Error case - t.Run("nil context", func(t *testing.T) { - c := &Compiler{} - c.applyDefaults() - // Using variable to create nil context to avoid linter issues - var nilContext context.Context - nilOpt := WithContext(nilContext) - err := nilOpt(c) - - require.Error(t, err) - require.Contains(t, err.Error(), "context cannot be nil") - }) - }) } // TestCompilerOptions_SetupLogger tests the setupLogger method @@ -421,12 +374,10 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { require.NotNil(t, c.options.RuntimeConfig, "runtime config should be initialized") require.NotNil(t, c.options.HostFunctions, "host functions should be initialized") require.Empty(t, c.options.HostFunctions, "host functions should be empty by default") - require.NotNil(t, c.ctx, "context should be initialized") }) t.Run("custom values preserved", func(t *testing.T) { customEntryPoint := "custom_entry" - customCtx := context.WithValue(t.Context(), constants.EvalData, "value") customConfig := wazero.NewRuntimeConfig() // Create a compiler with defaults @@ -435,13 +386,11 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { // Then set the values that would have been set by options c.entryPointName = customEntryPoint - c.ctx = customCtx c.options.RuntimeConfig = customConfig c.options.EnableWASI = false // Check that values are set as expected require.Equal(t, customEntryPoint, c.entryPointName) - require.Equal(t, customCtx, c.ctx) require.Equal(t, customConfig, c.options.RuntimeConfig) require.False(t, c.options.EnableWASI) }) @@ -494,7 +443,6 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { c := &Compiler{ entryPointName: "custom", logHandler: slog.NewTextHandler(bytes.NewBuffer(nil), nil), - ctx: t.Context(), options: &compile.Settings{ RuntimeConfig: wazero.NewRuntimeConfig(), }, @@ -507,7 +455,6 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { t.Run("nil handler and logger", func(t *testing.T) { c := &Compiler{ entryPointName: "test", - ctx: t.Context(), logHandler: nil, logger: nil, options: &compile.Settings{ @@ -523,7 +470,6 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { c := &Compiler{ entryPointName: "", logHandler: slog.NewTextHandler(bytes.NewBuffer(nil), nil), - ctx: t.Context(), options: &compile.Settings{ RuntimeConfig: wazero.NewRuntimeConfig(), }, @@ -538,7 +484,6 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { c := &Compiler{ entryPointName: "test", logHandler: slog.NewTextHandler(bytes.NewBuffer(nil), nil), - ctx: t.Context(), options: nil, } @@ -551,7 +496,6 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { c := &Compiler{ entryPointName: "test", logHandler: slog.NewTextHandler(bytes.NewBuffer(nil), nil), - ctx: t.Context(), options: &compile.Settings{ RuntimeConfig: nil, }, @@ -561,21 +505,6 @@ func TestCompilerOptions_DefaultsAndValidation(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "runtime config cannot be nil") }) - - t.Run("nil context", func(t *testing.T) { - c := &Compiler{ - entryPointName: "test", - logHandler: slog.NewTextHandler(bytes.NewBuffer(nil), nil), - ctx: nil, - options: &compile.Settings{ - RuntimeConfig: wazero.NewRuntimeConfig(), - }, - } - - err := c.validate() - require.Error(t, err) - require.Contains(t, err.Error(), "context cannot be nil") - }) }) } @@ -896,33 +825,6 @@ func TestCompilerOptions(t *testing.T) { require.Equal(t, hostFuncs, c.options.HostFunctions) }) }) - - t.Run("WithContext option", func(t *testing.T) { - ctx := t.Context() - - t.Run("valid context", func(t *testing.T) { - c := &Compiler{} - c.applyDefaults() - opt := WithContext(ctx) - err := opt(c) - - require.NoError(t, err) - require.Equal(t, ctx, c.ctx) - }) - - t.Run("nil context", func(t *testing.T) { - c := &Compiler{} - c.applyDefaults() - // We need to test our validation of nil contexts but without passing nil directly - // to satisfy the linter. Use a type conversion trick to create a nil context. - var nilContext context.Context - nilOpt := WithContext(nilContext) - err := nilOpt(c) - - require.Error(t, err) - require.Contains(t, err.Error(), "context cannot be nil") - }) - }) }) t.Run("Defaults and Validation", func(t *testing.T) { @@ -938,7 +840,6 @@ func TestCompilerOptions(t *testing.T) { require.NotNil(t, c.options.RuntimeConfig) require.NotNil(t, c.options.HostFunctions) require.Empty(t, c.options.HostFunctions) - require.NotNil(t, c.ctx) }) t.Run("defaults - entry point handling", func(t *testing.T) { @@ -946,7 +847,6 @@ func TestCompilerOptions(t *testing.T) { c := &Compiler{ entryPointName: "", options: &compile.Settings{}, - ctx: t.Context(), } c.applyDefaults() @@ -957,7 +857,6 @@ func TestCompilerOptions(t *testing.T) { c := &Compiler{ entryPointName: "initialValue", options: &compile.Settings{}, - ctx: t.Context(), } require.Equal(t, "initialValue", c.entryPointName) @@ -973,7 +872,6 @@ func TestCompilerOptions(t *testing.T) { c := &Compiler{ entryPointName: customEntryPoint, options: &compile.Settings{}, - ctx: t.Context(), } c.applyDefaults() diff --git a/engines/extism/evaluator/exec_helpers_test.go b/engines/extism/evaluator/exec_helpers_test.go index 9fe4ca5..791de7f 100644 --- a/engines/extism/evaluator/exec_helpers_test.go +++ b/engines/extism/evaluator/exec_helpers_test.go @@ -1,6 +1,7 @@ package evaluator import ( + "context" "io" "net/url" "strings" @@ -16,8 +17,8 @@ import ( // by NewExecutableUnit. type loaderMock struct{ mock.Mock } -func (m *loaderMock) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *loaderMock) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } @@ -36,8 +37,8 @@ func (m *loaderMock) GetSourceURL() *url.URL { // ExecutableContent (which may be nil). type compilerMock struct{ mock.Mock } -func (m *compilerMock) Compile(r io.ReadCloser) (script.ExecutableContent, error) { - args := m.Called(r) +func (m *compilerMock) Compile(ctx context.Context, r io.ReadCloser) (script.ExecutableContent, error) { + args := m.Called(ctx, r) if args.Get(0) == nil { return nil, args.Error(1) } @@ -61,11 +62,11 @@ func newExe( id = t.Name() } ldr := new(loaderMock) - ldr.On("GetReader"). + ldr.On("GetReader", mock.Anything). Return(io.NopCloser(strings.NewReader("dummy")), nil) cmp := new(compilerMock) - cmp.On("Compile", mock.Anything).Return(content, nil) - exe, err := script.NewExecutableUnit(nil, id, ldr, cmp, provider) + cmp.On("Compile", mock.Anything, mock.Anything).Return(content, nil) + exe, err := script.NewExecutableUnit(t.Context(), nil, id, ldr, cmp, provider) require.NoError(t, err) return exe } diff --git a/engines/extism/mock_loader_test.go b/engines/extism/mock_loader_test.go index 71ec572..167a717 100644 --- a/engines/extism/mock_loader_test.go +++ b/engines/extism/mock_loader_test.go @@ -1,6 +1,7 @@ package extism import ( + "context" "io" "net/url" @@ -19,8 +20,8 @@ func (m *loaderMock) GetSourceURL() *url.URL { return args.Get(0).(*url.URL) } -func (m *loaderMock) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *loaderMock) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } diff --git a/engines/extism/new.go b/engines/extism/new.go index 2945f39..2ddbc13 100644 --- a/engines/extism/new.go +++ b/engines/extism/new.go @@ -32,6 +32,7 @@ package extism import ( + "context" "errors" "fmt" @@ -95,7 +96,10 @@ func FromExtismLoader(ldr loader.Loader, opts ...Option) (*evaluator.Evaluator, execUnitID = u.String() } - execUnit, err := script.NewExecutableUnit(cfg.handler, execUnitID, ldr, comp, provider) + // FromExtismLoader is a one-shot startup constructor; compile uses a + // fresh Background context. Callers needing cancellable compile + // should drive script.NewExecutableUnit directly. + execUnit, err := script.NewExecutableUnit(context.Background(), cfg.handler, execUnitID, ldr, comp, provider) if err != nil { return nil, err } diff --git a/engines/extism/new_test.go b/engines/extism/new_test.go index 8d52de2..eec26e3 100644 --- a/engines/extism/new_test.go +++ b/engines/extism/new_test.go @@ -15,6 +15,7 @@ import ( "github.com/robbyt/go-polyscript/platform/data" "github.com/robbyt/go-polyscript/platform/script/loader" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -25,7 +26,7 @@ func newWASMLoader(t *testing.T) *loaderMock { require.NoError(t, err) mockLoader.On("GetSourceURL").Return(mockURL) wasmBytes := wasmdata.TestModule - mockLoader.On("GetReader").Return(io.NopCloser(bytes.NewReader(wasmBytes)), nil) + mockLoader.On("GetReader", mock.Anything).Return(io.NopCloser(bytes.NewReader(wasmBytes)), nil) return mockLoader } @@ -35,7 +36,7 @@ func newErrorLoader(t *testing.T, msg string) *loaderMock { mockURL, err := url.Parse("file:///test-wasm-file.wasm") require.NoError(t, err) mockLoader.On("GetSourceURL").Return(mockURL) - mockLoader.On("GetReader").Return(nil, errors.New(msg)) + mockLoader.On("GetReader", mock.Anything).Return(nil, errors.New(msg)) return mockLoader } @@ -148,7 +149,7 @@ func TestFromExtismLoader_LoaderError(t *testing.T) { func TestFromExtismLoader_NilSourceURL(t *testing.T) { mockLoader := new(loaderMock) mockLoader.On("GetSourceURL").Return(nil) - mockLoader.On("GetReader").Return(io.NopCloser(bytes.NewReader(wasmdata.TestModule)), nil) + mockLoader.On("GetReader", mock.Anything).Return(io.NopCloser(bytes.NewReader(wasmdata.TestModule)), nil) eval, err := FromExtismLoader(mockLoader, WithEntryPoint(wasmdata.EntrypointGreet)) require.NoError(t, err) diff --git a/engines/risor/compiler/compiler.go b/engines/risor/compiler/compiler.go index fafead4..a055a00 100644 --- a/engines/risor/compiler/compiler.go +++ b/engines/risor/compiler/compiler.go @@ -1,6 +1,7 @@ package compiler import ( + "context" "fmt" "io" "log/slog" @@ -48,25 +49,26 @@ func (c *Compiler) String() string { } // Compile turns the provided script content into runnable bytecode. -func (c *Compiler) Compile(scriptLoader io.ReadCloser) (script.ExecutableContent, error) { +// Cancelling ctx halts the Risor parser. +func (c *Compiler) Compile(ctx context.Context, scriptLoader io.ReadCloser) (script.ExecutableContent, error) { if scriptLoader == nil { return nil, ErrContentNil } + defer func() { + if err := scriptLoader.Close(); err != nil { + c.logger.WithGroup("compile").Warn("failed to close script reader", "error", err) + } + }() scriptBodyBytes, err := io.ReadAll(scriptLoader) if err != nil { return nil, fmt.Errorf("failed to read script: %w", err) } - err = scriptLoader.Close() - if err != nil { - return nil, fmt.Errorf("failed to close reader: %w", err) - } - - return c.compile(scriptBodyBytes) + return c.compile(ctx, scriptBodyBytes) } -func (c *Compiler) compile(scriptBodyBytes []byte) (*executable, error) { +func (c *Compiler) compile(ctx context.Context, scriptBodyBytes []byte) (*executable, error) { logger := c.logger.WithGroup("compile") if len(scriptBodyBytes) == 0 { return nil, ErrContentNil @@ -96,7 +98,7 @@ func (c *Compiler) compile(scriptBodyBytes []byte) (*executable, error) { logger.Debug("Starting Risor compilation", "scriptLength", len(trimmedScript)) - bc, err := compile.CompileWithGlobals(&scriptContent, c.globals) + bc, err := compile.CompileWithGlobals(ctx, &scriptContent, c.globals) if err != nil { return nil, fmt.Errorf("%w: %w", ErrValidationFailed, err) } diff --git a/engines/risor/compiler/compiler_test.go b/engines/risor/compiler/compiler_test.go index 19bdeea..5cce9b6 100644 --- a/engines/risor/compiler/compiler_test.go +++ b/engines/risor/compiler/compiler_test.go @@ -145,7 +145,7 @@ main() } // Execute test - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err, "Did not expect an error but got one") require.NotNil(t, execContent, "Expected execContent to be non-nil") require.Equal( @@ -218,7 +218,7 @@ main() } // Execute test - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.ErrorIs(t, err, tt.err, "Expected error %v, got %v", tt.err, err) @@ -235,7 +235,7 @@ main() require.NoError(t, err) require.NotNil(t, comp, "Expected compiler to be non-nil") - execContent, err := comp.Compile(nil) + execContent, err := comp.Compile(t.Context(), nil) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.ErrorIs(t, err, ErrContentNil, "Expected error to be ErrContentNil") @@ -251,7 +251,7 @@ main() // Create a reader that will return an error reader := &mockErrorReader{} - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.Contains( @@ -262,26 +262,22 @@ main() ) }) - t.Run("close error", func(t *testing.T) { + t.Run("close error is silenced", func(t *testing.T) { comp, err := New( WithLogHandler(slog.NewTextHandler(os.Stdout, nil)), ) require.NoError(t, err) require.NotNil(t, comp, "Expected compiler to be non-nil") - // Create a reader that will return an error on close + // Close errors are logged via the deferred cleanup, not returned — + // a successful read+compile still produces a usable executable. reader := newMockScriptReaderCloser(`"Hello, World!"`) reader.On("Close").Return(errors.New("test error")).Once() - execContent, err := comp.Compile(reader) - require.Error(t, err, "Expected an error but got none") - require.Nil(t, execContent, "Expected execContent to be nil") - require.Contains( - t, - err.Error(), - "failed to close reader", - "Expected error to contain 'failed to close reader'", - ) + execContent, err := comp.Compile(t.Context(), reader) + require.NoError(t, err) + require.NotNil(t, execContent) + reader.AssertExpectations(t) }) }) @@ -295,7 +291,7 @@ main() // Here we test that we can directly call the compile method with a byteslice scriptBytes := []byte(`"Hello, World!"`) - executable, err := comp.compile(scriptBytes) + executable, err := comp.compile(t.Context(), scriptBytes) require.NoError(t, err, "Did not expect an error but got one") require.NotNil(t, executable, "Expected execContent to be non-nil") require.Equal( @@ -349,7 +345,7 @@ func TestCompilerOptions(t *testing.T) { mockReader.On("Close").Return(nil) } - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) }) @@ -367,7 +363,7 @@ func TestCompilerOptions(t *testing.T) { mockReader.On("Close").Return(nil) } - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) }) @@ -390,7 +386,7 @@ func TestCompileError(t *testing.T) { require.NotNil(t, comp, "Expected compiler to be non-nil") // Execute test with nil reader - execContent, err := comp.Compile(nil) + execContent, err := comp.Compile(t.Context(), nil) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.ErrorIs(t, err, ErrContentNil, "Expected error to be ErrContentNil") @@ -406,7 +402,7 @@ func TestCompileWithBytecode(t *testing.T) { // Here we test that we can directly call the compile method with a byteslice scriptBytes := []byte(`"Hello, World!"`) - executable, err := comp.compile(scriptBytes) + executable, err := comp.compile(t.Context(), scriptBytes) require.NoError(t, err, "Did not expect an error but got one") require.NotNil(t, executable, "Expected execContent to be non-nil") require.Equal(t, string(scriptBytes), executable.GetSource(), "Script content does not match") @@ -430,7 +426,7 @@ func TestCompileIOError(t *testing.T) { // Create a reader that will return an error reader := &mockErrorReader{} - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.Contains( @@ -441,27 +437,22 @@ func TestCompileIOError(t *testing.T) { ) } -func TestCompileCloseError(t *testing.T) { - // Test that we return the correct error when there's an error closing the reader +func TestCompileCloseErrorIsSilenced(t *testing.T) { + // Close errors are logged via the deferred cleanup, not returned — + // a successful read+compile still produces a usable executable. comp, err := New( WithLogHandler(slog.NewTextHandler(os.Stdout, nil)), ) require.NoError(t, err) require.NotNil(t, comp, "Expected compiler to be non-nil") - // Create a reader that will return an error on close reader := newMockScriptReaderCloser(`"Hello, World!"`) reader.On("Close").Return(errors.New("test error")).Once() - execContent, err := comp.Compile(reader) - require.Error(t, err, "Expected an error but got none") - require.Nil(t, execContent, "Expected execContent to be nil") - require.Contains( - t, - err.Error(), - "failed to close reader", - "Expected error to contain 'failed to close reader'", - ) + execContent, err := comp.Compile(t.Context(), reader) + require.NoError(t, err) + require.NotNil(t, execContent) + reader.AssertExpectations(t) } // mockErrorReader implements io.ReadCloser for testing read errors diff --git a/engines/risor/compiler/internal/compile/compile.go b/engines/risor/compiler/internal/compile/compile.go index 622b6d3..e0781dc 100644 --- a/engines/risor/compiler/internal/compile/compile.go +++ b/engines/risor/compiler/internal/compile/compile.go @@ -12,13 +12,14 @@ import ( risorParser "github.com/deepnoodle-ai/risor/v2/pkg/parser" ) -// Compile parses and compiles the script content into bytecode -func Compile(scriptContent *string, cfg *risorCompiler.Config) (*bytecode.Code, error) { +// Compile parses and compiles the script content into bytecode. Cancelling ctx +// halts the parser. +func Compile(ctx context.Context, scriptContent *string, cfg *risorCompiler.Config) (*bytecode.Code, error) { if scriptContent == nil { return nil, ErrContentNil } - ast, err := risorParser.Parse(context.Background(), *scriptContent, nil) + ast, err := risorParser.Parse(ctx, *scriptContent, nil) if err != nil { return nil, fmt.Errorf("%w: %w", ErrCompileFailed, err) } @@ -35,7 +36,7 @@ func Compile(scriptContent *string, cfg *risorCompiler.Config) (*bytecode.Code, // which are needed when parsing a script that will eventually have globals injected at eval time. // For example, if a script uses a request or response object, it needs to be compiled with those // global names, even though they won't be available until eval time. -func CompileWithGlobals(scriptContent *string, globals []string) (*bytecode.Code, error) { +func CompileWithGlobals(ctx context.Context, scriptContent *string, globals []string) (*bytecode.Code, error) { // Start with the standard builtins env and add custom globals env := risor.Builtins() for _, g := range globals { @@ -50,5 +51,5 @@ func CompileWithGlobals(scriptContent *string, globals []string) (*bytecode.Code GlobalNames: globalNames, } - return Compile(scriptContent, cfg) + return Compile(ctx, scriptContent, cfg) } diff --git a/engines/risor/compiler/internal/compile/compile_test.go b/engines/risor/compiler/internal/compile/compile_test.go index c1debba..2b42779 100644 --- a/engines/risor/compiler/internal/compile/compile_test.go +++ b/engines/risor/compiler/internal/compile/compile_test.go @@ -10,7 +10,7 @@ import ( func TestCompileSuccess(t *testing.T) { scriptContent := `true` - code, err := Compile(&scriptContent, nil) + code, err := Compile(t.Context(), &scriptContent, nil) require.NoError(t, err) require.NotNil(t, code) } @@ -21,7 +21,7 @@ func TestCompileSyntaxError(t *testing.T) { "Hello, World! ` - code, err := Compile(&scriptContent, nil) + code, err := Compile(t.Context(), &scriptContent, nil) require.Error(t, err) require.Nil(t, code) require.ErrorIs(t, err, ErrCompileFailed) @@ -34,14 +34,14 @@ func TestCompileWithGlobals(t *testing.T) { ` globals := []string{"request"} - code, err := CompileWithGlobals(&scriptContent, globals) + code, err := CompileWithGlobals(t.Context(), &scriptContent, globals) require.NoError(t, err) require.NotNil(t, code) } // TestCompileNilContent tests the handling of nil script content func TestCompileNilContent(t *testing.T) { - code, err := Compile(nil, nil) + code, err := Compile(t.Context(), nil, nil) require.Error(t, err) require.Nil(t, code) require.ErrorIs(t, err, ErrContentNil) @@ -50,7 +50,7 @@ func TestCompileNilContent(t *testing.T) { // TestCompileWithGlobalsNilContent tests the handling of nil script content with globals func TestCompileWithGlobalsNilContent(t *testing.T) { globals := []string{"request"} - code, err := CompileWithGlobals(nil, globals) + code, err := CompileWithGlobals(t.Context(), nil, globals) require.Error(t, err) require.Nil(t, code) require.ErrorIs(t, err, ErrContentNil) @@ -63,7 +63,7 @@ func TestCompileWithGlobalsSyntaxError(t *testing.T) { ` globals := []string{"request"} - code, err := CompileWithGlobals(&scriptContent, globals) + code, err := CompileWithGlobals(t.Context(), &scriptContent, globals) require.Error(t, err) require.Nil(t, code) require.ErrorIs(t, err, ErrCompileFailed) diff --git a/engines/risor/evaluator/evaluator_test.go b/engines/risor/evaluator/evaluator_test.go index 0f26373..e7737ac 100644 --- a/engines/risor/evaluator/evaluator_test.go +++ b/engines/risor/evaluator/evaluator_test.go @@ -81,6 +81,7 @@ func (m *MockContent) EngineType() types.Type { // Helper function to create a test executable unit func createTestExecutable( + ctx context.Context, handler slog.Handler, ld loader.Loader, globals []string, @@ -93,7 +94,7 @@ func createTestExecutable( if err != nil { return nil, fmt.Errorf("failed to create compiler: %w", err) } - return script.NewExecutableUnit(handler, "test-id", ld, c, provider) + return script.NewExecutableUnit(ctx, handler, "test-id", ld, c, provider) } // TestEvaluator_Evaluate tests evaluating Risor scripts @@ -168,7 +169,7 @@ func TestEvaluator_Evaluate(t *testing.T) { ctxProvider := data.NewContextProvider(constants.EvalData) // Create executable unit and evaluator - exe, err := createTestExecutable(handler, ld, []string{constants.Ctx}, ctxProvider) + exe, err := createTestExecutable(t.Context(), handler, ld, []string{constants.Ctx}, ctxProvider) require.NoError(t, err) evaluator := New(handler, exe) require.NotNil(t, evaluator) @@ -544,7 +545,7 @@ range(1000000000).each(x => x) ld, err := loader.NewFromString(script) require.NoError(t, err) ctxProvider := data.NewContextProvider(constants.EvalData) - exe, err := createTestExecutable(handler, ld, []string{constants.Ctx}, ctxProvider) + exe, err := createTestExecutable(t.Context(), handler, ld, []string{constants.Ctx}, ctxProvider) require.NoError(t, err) eval := New(handler, exe) require.NotNil(t, eval) @@ -589,7 +590,7 @@ func TestEval_ErrorTypeExposesRisorDetails(t *testing.T) { ld, err := loader.NewFromString(script) require.NoError(t, err) ctxProvider := data.NewContextProvider(constants.EvalData) - exe, err := createTestExecutable(handler, ld, []string{constants.Ctx}, ctxProvider) + exe, err := createTestExecutable(t.Context(), handler, ld, []string{constants.Ctx}, ctxProvider) require.NoError(t, err) eval := New(handler, exe) require.NotNil(t, eval) diff --git a/engines/risor/evaluator/exec_helpers_test.go b/engines/risor/evaluator/exec_helpers_test.go index 087af8d..8d2f433 100644 --- a/engines/risor/evaluator/exec_helpers_test.go +++ b/engines/risor/evaluator/exec_helpers_test.go @@ -1,6 +1,7 @@ package evaluator import ( + "context" "io" "net/url" "strings" @@ -17,8 +18,8 @@ import ( // in evaluator_test.go (which has a non-matching GetSourceURL signature). type unitLoaderMock struct{ mock.Mock } -func (m *unitLoaderMock) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *unitLoaderMock) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } @@ -37,8 +38,8 @@ func (m *unitLoaderMock) GetSourceURL() *url.URL { // ExecutableContent (which may be nil). type compilerMock struct{ mock.Mock } -func (m *compilerMock) Compile(r io.ReadCloser) (script.ExecutableContent, error) { - args := m.Called(r) +func (m *compilerMock) Compile(ctx context.Context, r io.ReadCloser) (script.ExecutableContent, error) { + args := m.Called(ctx, r) if args.Get(0) == nil { return nil, args.Error(1) } @@ -62,11 +63,11 @@ func newExe( id = t.Name() } ldr := new(unitLoaderMock) - ldr.On("GetReader"). + ldr.On("GetReader", mock.Anything). Return(io.NopCloser(strings.NewReader("dummy")), nil) cmp := new(compilerMock) - cmp.On("Compile", mock.Anything).Return(content, nil) - exe, err := script.NewExecutableUnit(nil, id, ldr, cmp, provider) + cmp.On("Compile", mock.Anything, mock.Anything).Return(content, nil) + exe, err := script.NewExecutableUnit(t.Context(), nil, id, ldr, cmp, provider) require.NoError(t, err) return exe } diff --git a/engines/risor/mock_loader_test.go b/engines/risor/mock_loader_test.go index f0d699d..100fa7d 100644 --- a/engines/risor/mock_loader_test.go +++ b/engines/risor/mock_loader_test.go @@ -1,6 +1,7 @@ package risor import ( + "context" "io" "net/url" @@ -19,8 +20,8 @@ func (m *loaderMock) GetSourceURL() *url.URL { return args.Get(0).(*url.URL) } -func (m *loaderMock) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *loaderMock) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } diff --git a/engines/risor/new.go b/engines/risor/new.go index 8c5aaf1..d1a4a42 100644 --- a/engines/risor/new.go +++ b/engines/risor/new.go @@ -30,6 +30,7 @@ package risor import ( + "context" "fmt" "github.com/robbyt/go-polyscript/engines/risor/compiler" @@ -82,7 +83,10 @@ func FromRisorLoader(ldr loader.Loader, opts ...Option) (*evaluator.Evaluator, e execUnitID = u.String() } - execUnit, err := script.NewExecutableUnit(cfg.handler, execUnitID, ldr, comp, provider) + // FromRisorLoader is a one-shot startup constructor; compile uses a + // fresh Background context. Callers needing cancellable compile + // should drive script.NewExecutableUnit directly. + execUnit, err := script.NewExecutableUnit(context.Background(), cfg.handler, execUnitID, ldr, comp, provider) if err != nil { return nil, err } diff --git a/engines/risor/new_test.go b/engines/risor/new_test.go index 0daa236..9a9829e 100644 --- a/engines/risor/new_test.go +++ b/engines/risor/new_test.go @@ -14,6 +14,7 @@ import ( "github.com/robbyt/go-polyscript/platform/data" "github.com/robbyt/go-polyscript/platform/script/loader" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -40,7 +41,7 @@ func newErrorLoader(t *testing.T, msg string) *loaderMock { mockURL, err := url.Parse("file:///test-risor-file.risor") require.NoError(t, err) mockLoader.On("GetSourceURL").Return(mockURL) - mockLoader.On("GetReader").Return(nil, errors.New(msg)) + mockLoader.On("GetReader", mock.Anything).Return(nil, errors.New(msg)) return mockLoader } @@ -133,7 +134,7 @@ func TestFromRisorLoader_DiskLoader(t *testing.T) { assert.Equal(t, "risor.Evaluator", eval.String()) // Verify content was loaded correctly. - reader, err := diskLoader.GetReader() + reader, err := diskLoader.GetReader(t.Context()) require.NoError(t, err) content, err := io.ReadAll(reader) require.NoError(t, err) diff --git a/engines/starlark/compiler/compiler.go b/engines/starlark/compiler/compiler.go index c467f38..fcdf481 100644 --- a/engines/starlark/compiler/compiler.go +++ b/engines/starlark/compiler/compiler.go @@ -1,6 +1,7 @@ package compiler import ( + "context" "fmt" "io" "log/slog" @@ -47,21 +48,23 @@ func (c *Compiler) String() string { } // Compile turns the provided script content into runnable bytecode. -func (c *Compiler) Compile(scriptReader io.ReadCloser) (script.ExecutableContent, error) { +// Starlark's underlying parser is synchronous; ctx is accepted for +// interface conformance but the compile step itself does not honor it. +func (c *Compiler) Compile(_ context.Context, scriptReader io.ReadCloser) (script.ExecutableContent, error) { if scriptReader == nil { return nil, ErrContentNil } + defer func() { + if err := scriptReader.Close(); err != nil { + c.logger.WithGroup("compile").Warn("failed to close script reader", "error", err) + } + }() scriptBodyBytes, err := io.ReadAll(scriptReader) if err != nil { return nil, fmt.Errorf("failed to read script: %w", err) } - err = scriptReader.Close() - if err != nil { - return nil, fmt.Errorf("failed to close reader: %w", err) - } - return c.compile(scriptBodyBytes) } diff --git a/engines/starlark/compiler/compiler_test.go b/engines/starlark/compiler/compiler_test.go index 68e60d8..b6d2ffd 100644 --- a/engines/starlark/compiler/compiler_test.go +++ b/engines/starlark/compiler/compiler_test.go @@ -157,7 +157,7 @@ main() } // Execute test - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err, "Did not expect an error but got one") require.NotNil(t, execContent, "Expected execContent to be non-nil") require.Equal( @@ -225,7 +225,7 @@ main() } // Execute test - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.ErrorIs(t, err, tt.err, "Expected error %v, got %v", tt.err, err) @@ -242,7 +242,7 @@ main() require.NoError(t, err) require.NotNil(t, comp, "Expected compiler to be non-nil") - execContent, err := comp.Compile(nil) + execContent, err := comp.Compile(t.Context(), nil) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.ErrorIs(t, err, ErrContentNil, "Expected error to be ErrContentNil") @@ -258,7 +258,7 @@ main() // Create a reader that will return an error reader := &mockErrorReader{} - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.Contains( @@ -269,26 +269,22 @@ main() ) }) - t.Run("close error", func(t *testing.T) { + t.Run("close error is silenced", func(t *testing.T) { comp, err := New( WithLogHandler(slog.NewTextHandler(os.Stdout, nil)), ) require.NoError(t, err) require.NotNil(t, comp, "Expected compiler to be non-nil") - // Create a reader that will return an error on close + // Close errors are logged via the deferred cleanup, not returned — + // a successful read+compile still produces a usable executable. reader := newMockScriptReaderCloser(`print("Hello, World!")`) reader.On("Close").Return(errors.New("test error")).Once() - execContent, err := comp.Compile(reader) - require.Error(t, err, "Expected an error but got none") - require.Nil(t, execContent, "Expected execContent to be nil") - require.Contains( - t, - err.Error(), - "failed to close reader", - "Expected error to contain 'failed to close reader'", - ) + execContent, err := comp.Compile(t.Context(), reader) + require.NoError(t, err) + require.NotNil(t, execContent) + reader.AssertExpectations(t) }) }) } @@ -332,7 +328,7 @@ func TestCompilerOptions(t *testing.T) { mockReader.On("Close").Return(nil) } - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) }) @@ -352,7 +348,7 @@ func TestCompilerOptions(t *testing.T) { mockReader.On("Close").Return(nil) } - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) }) @@ -370,7 +366,7 @@ func TestCompilerOptions(t *testing.T) { mockReader.On("Close").Return(nil) } - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.NoError(t, err) require.NotNil(t, execContent) }) @@ -393,7 +389,7 @@ func TestCompileError(t *testing.T) { require.NotNil(t, comp, "Expected compiler to be non-nil") // Execute test with nil reader - execContent, err := comp.Compile(nil) + execContent, err := comp.Compile(t.Context(), nil) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.ErrorIs(t, err, ErrContentNil, "Expected error to be ErrContentNil") @@ -410,7 +406,7 @@ func TestCompileIOError(t *testing.T) { // Create a reader that will return an error reader := &mockErrorReader{} - execContent, err := comp.Compile(reader) + execContent, err := comp.Compile(t.Context(), reader) require.Error(t, err, "Expected an error but got none") require.Nil(t, execContent, "Expected execContent to be nil") require.Contains( @@ -421,27 +417,22 @@ func TestCompileIOError(t *testing.T) { ) } -func TestCompileCloseError(t *testing.T) { - // Test that we return the correct error when there's an error closing the reader +func TestCompileCloseErrorIsSilenced(t *testing.T) { + // Close errors are logged via the deferred cleanup, not returned — + // a successful read+compile still produces a usable executable. comp, err := New( WithLogHandler(slog.NewTextHandler(os.Stdout, nil)), ) require.NoError(t, err) require.NotNil(t, comp, "Expected compiler to be non-nil") - // Create a reader that will return an error on close reader := newMockScriptReaderCloser(`print("Hello, World!")`) reader.On("Close").Return(errors.New("test error")).Once() - execContent, err := comp.Compile(reader) - require.Error(t, err, "Expected an error but got none") - require.Nil(t, execContent, "Expected execContent to be nil") - require.Contains( - t, - err.Error(), - "failed to close reader", - "Expected error to contain 'failed to close reader'", - ) + execContent, err := comp.Compile(t.Context(), reader) + require.NoError(t, err) + require.NotNil(t, execContent) + reader.AssertExpectations(t) } func TestCompilerString(t *testing.T) { diff --git a/engines/starlark/evaluator/evaluator_test.go b/engines/starlark/evaluator/evaluator_test.go index 034bd04..1b649f4 100644 --- a/engines/starlark/evaluator/evaluator_test.go +++ b/engines/starlark/evaluator/evaluator_test.go @@ -42,6 +42,7 @@ func evalBuilder(t *testing.T, scriptContent string) (*script.ExecutableUnit, *E require.NoError(t, err, "Failed to create compiler") exe, err := script.NewExecutableUnit( + t.Context(), handler, scriptContent, loader, @@ -378,7 +379,8 @@ result = spin() select { case err := <-done: require.Error(t, err) - require.True(t, + require.True( + t, errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "context canceled") || strings.Contains(err.Error(), "cancel"), diff --git a/engines/starlark/evaluator/exec_helpers_test.go b/engines/starlark/evaluator/exec_helpers_test.go index b61f285..e4d1153 100644 --- a/engines/starlark/evaluator/exec_helpers_test.go +++ b/engines/starlark/evaluator/exec_helpers_test.go @@ -1,6 +1,7 @@ package evaluator import ( + "context" "io" "net/url" "strings" @@ -16,8 +17,8 @@ import ( // the loader's behavior, only that NewExecutableUnit can call GetReader(). type loaderMock struct{ mock.Mock } -func (m *loaderMock) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *loaderMock) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } @@ -36,8 +37,8 @@ func (m *loaderMock) GetSourceURL() *url.URL { // ExecutableContent (which may be nil). type compilerMock struct{ mock.Mock } -func (m *compilerMock) Compile(r io.ReadCloser) (script.ExecutableContent, error) { - args := m.Called(r) +func (m *compilerMock) Compile(ctx context.Context, r io.ReadCloser) (script.ExecutableContent, error) { + args := m.Called(ctx, r) if args.Get(0) == nil { return nil, args.Error(1) } @@ -61,11 +62,11 @@ func newExe( id = t.Name() } ldr := new(loaderMock) - ldr.On("GetReader"). + ldr.On("GetReader", mock.Anything). Return(io.NopCloser(strings.NewReader("dummy")), nil) cmp := new(compilerMock) - cmp.On("Compile", mock.Anything).Return(content, nil) - exe, err := script.NewExecutableUnit(nil, id, ldr, cmp, provider) + cmp.On("Compile", mock.Anything, mock.Anything).Return(content, nil) + exe, err := script.NewExecutableUnit(t.Context(), nil, id, ldr, cmp, provider) require.NoError(t, err) return exe } diff --git a/engines/starlark/mock_loader_test.go b/engines/starlark/mock_loader_test.go index 311c2ae..5b2007a 100644 --- a/engines/starlark/mock_loader_test.go +++ b/engines/starlark/mock_loader_test.go @@ -1,6 +1,7 @@ package starlark import ( + "context" "io" "net/url" @@ -19,8 +20,8 @@ func (m *loaderMock) GetSourceURL() *url.URL { return args.Get(0).(*url.URL) } -func (m *loaderMock) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *loaderMock) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } diff --git a/engines/starlark/new.go b/engines/starlark/new.go index 9b385d3..f68dd8b 100644 --- a/engines/starlark/new.go +++ b/engines/starlark/new.go @@ -30,6 +30,7 @@ package starlark import ( + "context" "fmt" "github.com/robbyt/go-polyscript/engines/starlark/compiler" @@ -82,7 +83,10 @@ func FromStarlarkLoader(ldr loader.Loader, opts ...Option) (*evaluator.Evaluator execUnitID = u.String() } - execUnit, err := script.NewExecutableUnit(cfg.handler, execUnitID, ldr, comp, provider) + // FromStarlarkLoader is a one-shot startup constructor; compile uses a + // fresh Background context. Callers needing cancellable compile + // should drive script.NewExecutableUnit directly. + execUnit, err := script.NewExecutableUnit(context.Background(), cfg.handler, execUnitID, ldr, comp, provider) if err != nil { return nil, err } diff --git a/engines/starlark/new_test.go b/engines/starlark/new_test.go index 12c2510..93af763 100644 --- a/engines/starlark/new_test.go +++ b/engines/starlark/new_test.go @@ -14,6 +14,7 @@ import ( "github.com/robbyt/go-polyscript/platform/data" "github.com/robbyt/go-polyscript/platform/script/loader" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -40,7 +41,7 @@ func newErrorLoader(t *testing.T, msg string) *loaderMock { mockURL, err := url.Parse("file:///test-starlark-file.star") require.NoError(t, err) mockLoader.On("GetSourceURL").Return(mockURL) - mockLoader.On("GetReader").Return(nil, errors.New(msg)) + mockLoader.On("GetReader", mock.Anything).Return(nil, errors.New(msg)) return mockLoader } @@ -137,7 +138,7 @@ func TestFromStarlarkLoader_DiskLoader(t *testing.T) { require.NotNil(t, eval) assert.Equal(t, "starlark.Evaluator", eval.String()) - reader, err := diskLoader.GetReader() + reader, err := diskLoader.GetReader(t.Context()) require.NoError(t, err) content, err := io.ReadAll(reader) require.NoError(t, err) diff --git a/platform/script/compiler.go b/platform/script/compiler.go index 2d15d3c..50c2f6f 100644 --- a/platform/script/compiler.go +++ b/platform/script/compiler.go @@ -1,29 +1,18 @@ package script -import "io" +import ( + "context" + "io" +) // Compiler defines the interface for validating scripts before execution. // It checks syntax and semantics, and may perform parsing, compilation, // and optimization. A valid script is returned as ExecutableContent. -// -// Example usage: -// -// var comp Compiler = NewRisorCompiler(globals) -// executableContent, err := comp.Compile(&scriptContent) -// if err != nil { -// // Handle validation error -// } -// // Use executableContent for execution type Compiler interface { // Compile checks if a script is valid and returns it as executable content. // The returned ExecutableContent contains the validated and possibly compiled - // script ready for execution. - // - // Parameters: - // - script: A pointer to the script content string - // - // Returns: - // - ExecutableContent: The validated script - // - error: Details about validation failures (syntax errors, undefined globals) - Compile(scriptReader io.ReadCloser) (ExecutableContent, error) + // script ready for execution. Implementations whose underlying parser is + // ctx-aware (Risor, Extism) honor cancellation; Starlark's parser is + // synchronous and ignores ctx. + Compile(ctx context.Context, scriptReader io.ReadCloser) (ExecutableContent, error) } diff --git a/platform/script/compiler_test.go b/platform/script/compiler_test.go index 273b284..55003de 100644 --- a/platform/script/compiler_test.go +++ b/platform/script/compiler_test.go @@ -79,10 +79,10 @@ func TestCompiler(t *testing.T) { reader.On("Close").Return(nil).Maybe() // Set expectations - mockCompiler.On("Compile", reader).Return(tt.mockReturn, tt.mockError) + mockCompiler.On("Compile", mock.Anything, reader).Return(tt.mockReturn, tt.mockError) // Execute test - result, err := mockCompiler.Compile(reader) + result, err := mockCompiler.Compile(t.Context(), reader) // Verify results if tt.expectError { diff --git a/platform/script/executableUnit.go b/platform/script/executableUnit.go index 7e6479a..6a281b0 100644 --- a/platform/script/executableUnit.go +++ b/platform/script/executableUnit.go @@ -1,6 +1,7 @@ package script import ( + "context" "errors" "fmt" "log/slog" @@ -32,7 +33,11 @@ type ExecutableUnit struct { // NewExecutableUnit creates a new ExecutableUnit from the provided loader and compiler. // The dataProvider parameter provides runtime data for script evaluation. // A nil handler is permitted and inherits from slog.Default. +// The supplied ctx flows into the loader (for IO cancellation) and the +// compiler (for parse/compile cancellation); it is not retained on the +// returned unit. func NewExecutableUnit( + ctx context.Context, handler slog.Handler, versionID string, scriptLoader loader.Loader, @@ -45,12 +50,12 @@ func NewExecutableUnit( return nil, errors.New("compiler is nil") } - reader, err := scriptLoader.GetReader() + reader, err := scriptLoader.GetReader(ctx) if err != nil { return nil, fmt.Errorf("failed to get reader from loader: %w", err) } - exe, err := compiler.Compile(reader) + exe, err := compiler.Compile(ctx, reader) if err != nil { return nil, fmt.Errorf("compiler failed: %w", err) } diff --git a/platform/script/executableUnit_test.go b/platform/script/executableUnit_test.go index 667c578..2c69197 100644 --- a/platform/script/executableUnit_test.go +++ b/platform/script/executableUnit_test.go @@ -1,6 +1,7 @@ package script import ( + "context" "errors" "fmt" "io" @@ -24,8 +25,11 @@ type mockLoader struct { mock.Mock } -func (m *mockLoader) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *mockLoader) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } return args.Get(0).(io.ReadCloser), args.Error(1) } @@ -107,19 +111,20 @@ func TestNewVersion(t *testing.T) { lod, err := loader.NewFromString(scriptContent) require.NoError(t, err, "Expected no error when creating loader") - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err, "Expected no error when getting reader") // Create mock loader instead of real loader mockLoader := new(mockLoader) - mockLoader.On("GetReader").Return(reader, nil) + mockLoader.On("GetReader", mock.Anything).Return(reader, nil) // Setup mock compiler with same reader instance comp := new(MockCompiler) - comp.On("Compile", reader).Return(&MockExecutableContent{}, nil) + comp.On("Compile", mock.Anything, reader).Return(&MockExecutableContent{}, nil) // Create executable unit exe, err := NewExecutableUnit( + t.Context(), logHandler, t.Name(), mockLoader, @@ -140,14 +145,15 @@ func TestNewVersion(t *testing.T) { lod, err := loader.NewFromString(scriptBody) require.NoError(t, err, "Expected no error when creating a new loader with valid content") - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err, "Expected no error when getting reader") comp := new(MockCompiler) mockContent := new(MockExecutableContent) - comp.On("Compile", reader).Return(mockContent, nil).Once() + comp.On("Compile", mock.Anything, reader).Return(mockContent, nil).Once() exe, err := NewExecutableUnit( + t.Context(), logHandler, t.Name(), lod, @@ -181,14 +187,15 @@ func TestNewVersion(t *testing.T) { lod, err := loader.NewFromString(scriptBody) require.NoError(t, err, "Expected no error when creating a new loader with empty content") - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err, "Expected no error when getting reader") comp := new(MockCompiler) validationError := errors.New("validation failed") - comp.On("Compile", reader).Return(nil, validationError).Once() + comp.On("Compile", mock.Anything, reader).Return(nil, validationError).Once() exe, err := NewExecutableUnit( + t.Context(), logHandler, t.Name(), lod, @@ -210,22 +217,23 @@ func TestNewVersion(t *testing.T) { lod, err := loader.NewFromString(scriptContent) require.NoError(t, err, "Expected no error when creating loader") - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err, "Expected no error when getting reader") // Create mock loader instead of real loader mockLoader := new(mockLoader) - mockLoader.On("GetReader").Return(reader, nil) + mockLoader.On("GetReader", mock.Anything).Return(reader, nil) // Setup mock compiler with same reader instance mockCompiler := new(MockCompiler) mockContent := new(MockExecutableContent) // Add expectation for GetSource mockContent.On("GetSource").Return(scriptContent) - mockCompiler.On("Compile", reader).Return(mockContent, nil) + mockCompiler.On("Compile", mock.Anything, reader).Return(mockContent, nil) // Create executable unit with empty ID exe, err := NewExecutableUnit( + t.Context(), logHandler, "", mockLoader, @@ -252,6 +260,7 @@ func TestNewVersion(t *testing.T) { t.Run("NilCompiler", func(t *testing.T) { exe, err := NewExecutableUnit( + t.Context(), logHandler, "test", &mockLoader{}, @@ -276,13 +285,14 @@ func TestNewVersion(t *testing.T) { // Setup mock loader with source URL mockLoader := new(mockLoader) - mockLoader.On("GetReader").Return(mockReader, nil) + mockLoader.On("GetReader", mock.Anything).Return(mockReader, nil) // Setup mock compiler with expected error mockCompiler := new(MockCompiler) - mockCompiler.On("Compile", mockReader).Return(nil, errors.New("empty content")) + mockCompiler.On("Compile", mock.Anything, mockReader).Return(nil, errors.New("empty content")) // Create executable unit exe, err := NewExecutableUnit( + t.Context(), logHandler, "test", mockLoader, @@ -299,12 +309,13 @@ func TestNewVersion(t *testing.T) { }) t.Run("GetReaderError", func(t *testing.T) { - mockReader := new(mockReadCloser) - + // Real loaders follow the Go convention: nil reader on error. + // Modeling it that way here keeps the mock contract honest. mockLoader := new(mockLoader) - mockLoader.On("GetReader").Return(mockReader, errors.New("get reader error")).Once() + mockLoader.On("GetReader", mock.Anything).Return(nil, errors.New("get reader error")).Once() exe, err := NewExecutableUnit( + t.Context(), logHandler, "test", mockLoader, @@ -313,8 +324,6 @@ func TestNewVersion(t *testing.T) { ) require.Error(t, err) require.Nil(t, exe) - - mockReader.AssertExpectations(t) mockLoader.AssertExpectations(t) }) @@ -324,14 +333,15 @@ func TestNewVersion(t *testing.T) { // Setup mock loader mockLoader := new(mockLoader) - mockLoader.On("GetReader").Return(mockReader, nil).Once() + mockLoader.On("GetReader", mock.Anything).Return(mockReader, nil).Once() // Setup mock compiler with same reader instance mockCompiler := new(MockCompiler) - mockCompiler.On("Compile", mockReader).Return(nil, errors.New("compile failed")).Once() + mockCompiler.On("Compile", mock.Anything, mockReader).Return(nil, errors.New("compile failed")).Once() // Create executable unit exe, err := NewExecutableUnit( + t.Context(), logHandler, "test", mockLoader, @@ -393,10 +403,10 @@ func TestNewVersionWithScriptData(t *testing.T) { lod, err := loader.NewFromString(scriptContent) require.NoError(t, err, "Expected no error when creating loader") - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err, "Expected no error when getting reader") - mockCompiler.On("Compile", reader).Return(mockContent, nil).Once() + mockCompiler.On("Compile", mock.Anything, reader).Return(mockContent, nil).Once() // Create loader directly from string instead of file loader, err := loader.NewFromString(scriptContent) @@ -410,6 +420,7 @@ func TestNewVersionWithScriptData(t *testing.T) { // Create executable unit exe, err := NewExecutableUnit( + t.Context(), logHandler, t.Name(), loader, @@ -434,15 +445,16 @@ func TestNewVersionWithScriptData(t *testing.T) { lod, err := loader.NewFromString(scriptBody) require.NoError(t, err, "Expected no error when creating a new loader with valid content") - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err, "Expected no error when getting reader") comp := new(MockCompiler) mockContent := new(MockExecutableContent) - comp.On("Compile", reader).Return(mockContent, nil).Once() + comp.On("Compile", mock.Anything, reader).Return(mockContent, nil).Once() exe, err := NewExecutableUnit( + t.Context(), logHandler, t.Name(), lod, @@ -474,16 +486,16 @@ func TestNewExecutableUnit_NilHandler(t *testing.T) { lod, err := loader.NewFromString("any content") require.NoError(t, err) - reader, err := lod.GetReader() + reader, err := lod.GetReader(t.Context()) require.NoError(t, err) mockLdr := new(mockLoader) - mockLdr.On("GetReader").Return(reader, nil) + mockLdr.On("GetReader", mock.Anything).Return(reader, nil) comp := new(MockCompiler) - comp.On("Compile", reader).Return(&MockExecutableContent{}, nil) + comp.On("Compile", mock.Anything, reader).Return(&MockExecutableContent{}, nil) - exe, err := NewExecutableUnit(nil, t.Name(), mockLdr, comp, data.NewStaticProvider(emptyScriptData)) + exe, err := NewExecutableUnit(t.Context(), nil, t.Name(), mockLdr, comp, data.NewStaticProvider(emptyScriptData)) require.NoError(t, err) require.NotNil(t, exe) require.Equal(t, t.Name(), exe.GetID()) diff --git a/platform/script/loader/fromBytes.go b/platform/script/loader/fromBytes.go index 51661a9..205fffe 100644 --- a/platform/script/loader/fromBytes.go +++ b/platform/script/loader/fromBytes.go @@ -2,6 +2,7 @@ package loader import ( "bytes" + "context" "fmt" "io" "net/url" @@ -46,8 +47,9 @@ func (l *FromBytes) String() string { return fmt.Sprintf("loader.FromBytes{Bytes: %d}", len(l.content)) } -// GetReader returns a new reader for the stored content. -func (l *FromBytes) GetReader() (io.ReadCloser, error) { +// GetReader returns a new reader for the stored content. The ctx is +// unused: bytes are already in memory. +func (l *FromBytes) GetReader(_ context.Context) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(l.content)), nil } diff --git a/platform/script/loader/fromBytes_test.go b/platform/script/loader/fromBytes_test.go index bdc1e11..f371f2f 100644 --- a/platform/script/loader/fromBytes_test.go +++ b/platform/script/loader/fromBytes_test.go @@ -100,7 +100,7 @@ func TestFromBytes_GetReader(t *testing.T) { loader, err := NewFromBytes(content) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) data, err := io.ReadAll(reader) @@ -115,7 +115,7 @@ func TestFromBytes_GetReader(t *testing.T) { require.NoError(t, err) // First read - reader1, err := loader.GetReader() + reader1, err := loader.GetReader(t.Context()) require.NoError(t, err) data1, err := io.ReadAll(reader1) require.NoError(t, err) @@ -123,7 +123,7 @@ func TestFromBytes_GetReader(t *testing.T) { require.NoError(t, reader1.Close()) // Second read - reader2, err := loader.GetReader() + reader2, err := loader.GetReader(t.Context()) require.NoError(t, err) data2, err := io.ReadAll(reader2) require.NoError(t, err) @@ -136,7 +136,7 @@ func TestFromBytes_GetReader(t *testing.T) { loader, err := NewFromBytes(content) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, reader.Close(), "Failed to close reader") @@ -160,7 +160,7 @@ func TestFromBytes_GetReader(t *testing.T) { loader, err := NewFromBytes(content) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) data, err := io.ReadAll(reader) require.NoError(t, err) diff --git a/platform/script/loader/fromDisk.go b/platform/script/loader/fromDisk.go index 4797297..cfd5102 100644 --- a/platform/script/loader/fromDisk.go +++ b/platform/script/loader/fromDisk.go @@ -1,6 +1,7 @@ package loader import ( + "context" "fmt" "io" "log/slog" @@ -61,7 +62,9 @@ func (l *FromDisk) String() string { noChkSum := fmt.Sprintf("loader.FromDisk{Path: %s}", l.path) if l.sourceURL != nil { - reader, err := l.GetReader() + // Background ctx is fine here: String() is a synchronous debug + // helper, and FromDisk.GetReader's os.Open is synchronous anyway. + reader, err := l.GetReader(context.Background()) if err != nil { return noChkSum } @@ -86,8 +89,9 @@ func (l *FromDisk) String() string { return fmt.Sprintf("loader.FromDisk{Path: %s, SHA256: %s}", l.path, chksum) } -func (l *FromDisk) GetReader() (io.ReadCloser, error) { - // Just return a reader for the file +// GetReader opens the backing file. ctx is accepted for interface +// conformance; os.Open itself is synchronous and doesn't honor it. +func (l *FromDisk) GetReader(_ context.Context) (io.ReadCloser, error) { return os.Open(l.sourceURL.Path) } diff --git a/platform/script/loader/fromDisk_test.go b/platform/script/loader/fromDisk_test.go index 6587790..0d0a303 100644 --- a/platform/script/loader/fromDisk_test.go +++ b/platform/script/loader/fromDisk_test.go @@ -156,7 +156,7 @@ func TestFromDisk_GetReader(t *testing.T) { require.NoError(t, err, "Failed to create loader") // Get and read from reader - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err, "Failed to get reader") verifyReaderContent(t, reader, testContent) @@ -185,7 +185,7 @@ func TestFromDisk_GetReader(t *testing.T) { loader, err := NewFromDisk(nonExistingFile) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Nil(t, reader) require.Contains(t, err.Error(), "no such file or directory") diff --git a/platform/script/loader/fromHTTP.go b/platform/script/loader/fromHTTP.go index 37186f3..169ab99 100644 --- a/platform/script/loader/fromHTTP.go +++ b/platform/script/loader/fromHTTP.go @@ -225,22 +225,13 @@ func NewFromHTTPWithOptions(rawURL string, options *HTTPOptions) (*FromHTTP, err }, nil } -// GetReader returns a reader for the HTTP content. -// This method is part of the Loader interface and is used internally by -// the polyscript system to fetch the script content. +// GetReader returns a reader for the HTTP content. The ctx flows into +// the HTTP request and the authenticator so a cancelled ctx aborts the +// fetch. // // The returned io.ReadCloser must be closed by the caller when done. // HTTP errors are handled and converted to appropriate error types. -func (l *FromHTTP) GetReader() (io.ReadCloser, error) { - return l.GetReaderWithContext(context.Background()) -} - -// GetReaderWithContext returns a reader for the HTTP content with context support. -// This allows for request cancellation and timeouts via context. -// -// The returned io.ReadCloser must be closed by the caller when done. -// HTTP errors are handled and converted to appropriate error types. -func (l *FromHTTP) GetReaderWithContext(ctx context.Context) (io.ReadCloser, error) { +func (l *FromHTTP) GetReader(ctx context.Context) (io.ReadCloser, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, l.url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -342,9 +333,8 @@ func (l *FromHTTP) GetSourceURL() *url.URL { // it never performs a network request. // // The "SHA256: " form is included only after the body has been -// fetched at least once via [FromHTTP.GetReader] (or -// [FromHTTP.GetReaderWithContext]) and the response was buffered through -// [FromHTTP.cappedBody]. Until then — or when [HTTPOptions.MaxBodySize] +// fetched at least once via [FromHTTP.GetReader] and the response was +// buffered through [FromHTTP.cappedBody]. Until then — or when [HTTPOptions.MaxBodySize] // is negative (which disables buffering) — String returns the // no-checksum form. func (l *FromHTTP) String() string { diff --git a/platform/script/loader/fromHTTP_test.go b/platform/script/loader/fromHTTP_test.go index f37833d..530dbec 100644 --- a/platform/script/loader/fromHTTP_test.go +++ b/platform/script/loader/fromHTTP_test.go @@ -379,7 +379,7 @@ func TestFromHTTP_GetReader(t *testing.T) { require.NoError(t, err) // Use real server with helper - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) verifyReaderContent(t, reader, testScript) }) @@ -397,7 +397,7 @@ func TestFromHTTP_GetReader(t *testing.T) { loader, err := NewFromHTTP(testURL) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Contains(t, err.Error(), "HTTP 401") require.Nil(t, reader) @@ -416,7 +416,7 @@ func TestFromHTTP_GetReader(t *testing.T) { loader, err := NewFromHTTP(testURL) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Contains(t, err.Error(), "HTTP 404") require.Nil(t, reader) @@ -437,14 +437,14 @@ func TestFromHTTP_GetReader(t *testing.T) { } loader.client = mockClient - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Contains(t, err.Error(), "failed to execute HTTP request") require.Nil(t, reader) }) } -func TestFromHTTP_GetReaderWithContext(t *testing.T) { +func TestFromHTTP_GetReaderCancellation(t *testing.T) { t.Parallel() const testScript = FunctionContent @@ -463,7 +463,7 @@ func TestFromHTTP_GetReaderWithContext(t *testing.T) { require.NoError(t, err) ctx := t.Context() - reader, err := loader.GetReaderWithContext(ctx) + reader, err := loader.GetReader(ctx) require.NoError(t, err) require.NotNil(t, reader) verifyReaderContent(t, reader, testScript) @@ -496,7 +496,7 @@ func TestFromHTTP_GetReaderWithContext(t *testing.T) { } loader.client = mockClient - reader, err := loader.GetReaderWithContext(ctx) + reader, err := loader.GetReader(ctx) require.Error(t, err) require.Contains(t, err.Error(), "context canceled") require.Nil(t, reader) @@ -532,7 +532,7 @@ func TestFromHTTP_GetReaderWithContext(t *testing.T) { } loader.client = mockClient - reader, err := loader.GetReaderWithContext(ctx) + reader, err := loader.GetReader(ctx) require.Error(t, err) require.Contains(t, err.Error(), "context deadline exceeded") require.Nil(t, reader) @@ -559,7 +559,7 @@ func TestFromHTTP_String(t *testing.T) { // Issue #96: String() no longer fetches on its own. The SHA is // cached during the production GetReader path; drain a reader // here to populate it. - r, err := loader.GetReader() + r, err := loader.GetReader(t.Context()) require.NoError(t, err) _, copyErr := io.Copy(io.Discard, r) require.NoError(t, copyErr) @@ -646,7 +646,7 @@ func TestFromHTTP_String_NoNetworkRoundTrip(t *testing.T) { loader, err := NewFromHTTP(server.URL + "/script.js") require.NoError(t, err) - r, err := loader.GetReader() + r, err := loader.GetReader(t.Context()) require.NoError(t, err) _, copyErr := io.Copy(io.Discard, r) require.NoError(t, copyErr) @@ -680,7 +680,7 @@ func TestFromHTTP_String_NoNetworkRoundTrip(t *testing.T) { }() go func() { defer wg.Done() - r, err := loader.GetReader() + r, err := loader.GetReader(t.Context()) if err != nil { return } @@ -706,7 +706,7 @@ func TestFromHTTP_String_NoNetworkRoundTrip(t *testing.T) { loader, err := NewFromHTTPWithOptions(server.URL+"/script.js", opts) require.NoError(t, err) - r, err := loader.GetReader() + r, err := loader.GetReader(t.Context()) require.NoError(t, err) _, copyErr := io.Copy(io.Discard, r) require.NoError(t, copyErr) @@ -836,7 +836,7 @@ func TestFromHTTP_MaxBodySize(t *testing.T) { loader, err := NewFromHTTPWithOptions(server.URL+"/s", opts) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) require.NotNil(t, reader) require.NoError(t, reader.Close()) @@ -850,7 +850,7 @@ func TestFromHTTP_MaxBodySize(t *testing.T) { loader, err := NewFromHTTPWithOptions(server.URL+"/s", opts) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) require.NotNil(t, reader) require.NoError(t, reader.Close()) @@ -864,7 +864,7 @@ func TestFromHTTP_MaxBodySize(t *testing.T) { loader, err := NewFromHTTPWithOptions(server.URL+"/s", opts) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Nil(t, reader) require.ErrorIs(t, err, ErrScriptTooLarge) @@ -880,7 +880,7 @@ func TestFromHTTP_MaxBodySize(t *testing.T) { loader, err := NewFromHTTPWithOptions(server.URL+"/s", opts) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) require.NotNil(t, reader) require.NoError(t, reader.Close()) @@ -902,7 +902,7 @@ func TestFromHTTP_MaxBodySize(t *testing.T) { loader, err := NewFromHTTPWithOptions(server.URL+"/s", opts) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Nil(t, reader) require.ErrorIs(t, err, ErrScriptTooLarge) @@ -916,7 +916,7 @@ func TestFromHTTP_MaxBodySize(t *testing.T) { loader, err := NewFromHTTP(server.URL + "/s") require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.Error(t, err) require.Nil(t, reader) require.ErrorIs(t, err, ErrScriptTooLarge) diff --git a/platform/script/loader/fromIoReader.go b/platform/script/loader/fromIoReader.go index 4be0a4c..8e21e80 100644 --- a/platform/script/loader/fromIoReader.go +++ b/platform/script/loader/fromIoReader.go @@ -2,6 +2,7 @@ package loader import ( "bytes" + "context" "fmt" "io" "net/url" @@ -67,8 +68,9 @@ func (l *FromIoReader) String() string { ) } -// GetReader returns a new reader for the stored content. -func (l *FromIoReader) GetReader() (io.ReadCloser, error) { +// GetReader returns a new reader for the stored content. The ctx is +// unused: the bytes were buffered at construction. +func (l *FromIoReader) GetReader(_ context.Context) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(l.content)), nil } diff --git a/platform/script/loader/fromIoReader_test.go b/platform/script/loader/fromIoReader_test.go index 075a9a0..8caf5bb 100644 --- a/platform/script/loader/fromIoReader_test.go +++ b/platform/script/loader/fromIoReader_test.go @@ -77,7 +77,7 @@ func TestNewFromIoReader(t *testing.T) { } // Use GetReader to verify the content - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) content, err := io.ReadAll(reader) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestFromIoReader_GetReader(t *testing.T) { require.NoError(t, err) // First read - reader1, err := loader.GetReader() + reader1, err := loader.GetReader(t.Context()) require.NoError(t, err) content1, err := io.ReadAll(reader1) require.NoError(t, err) @@ -152,7 +152,7 @@ func TestFromIoReader_GetReader(t *testing.T) { require.NoError(t, reader1.Close()) // Second read should work the same way - reader2, err := loader.GetReader() + reader2, err := loader.GetReader(t.Context()) require.NoError(t, err) content2, err := io.ReadAll(reader2) require.NoError(t, err) @@ -166,9 +166,9 @@ func TestFromIoReader_GetReader(t *testing.T) { require.NoError(t, err) // Get two readers - reader1, err := loader.GetReader() + reader1, err := loader.GetReader(t.Context()) require.NoError(t, err) - reader2, err := loader.GetReader() + reader2, err := loader.GetReader(t.Context()) require.NoError(t, err) // Read partial content from first reader diff --git a/platform/script/loader/fromString.go b/platform/script/loader/fromString.go index 630b980..dcc0ab4 100644 --- a/platform/script/loader/fromString.go +++ b/platform/script/loader/fromString.go @@ -1,6 +1,7 @@ package loader import ( + "context" "encoding/base64" "fmt" "io" @@ -60,7 +61,9 @@ func (l *FromString) String() string { return fmt.Sprintf("loader.FromString{Chars: %d}", len(l.content)) } -func (l *FromString) GetReader() (io.ReadCloser, error) { +// GetReader returns a new reader over the stored string. The ctx is +// unused: the string is already in memory. +func (l *FromString) GetReader(_ context.Context) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(l.content)), nil } diff --git a/platform/script/loader/fromString_test.go b/platform/script/loader/fromString_test.go index 56035ed..2d273b6 100644 --- a/platform/script/loader/fromString_test.go +++ b/platform/script/loader/fromString_test.go @@ -99,7 +99,7 @@ func TestFromString_GetReader(t *testing.T) { loader, err := NewFromString(content) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) verifyReaderContent(t, reader, content) @@ -118,7 +118,7 @@ func TestFromString_GetReader(t *testing.T) { loader, err := NewFromString(content) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, reader.Close(), "Failed to close reader") @@ -217,7 +217,7 @@ func TestNewFromStringBase64(t *testing.T) { loader, err := NewFromStringBase64(encodedScript) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) verifyReaderContent(t, reader, script) @@ -227,7 +227,7 @@ func TestNewFromStringBase64(t *testing.T) { loader, err := NewFromStringBase64(script) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) verifyReaderContent(t, reader, script) @@ -250,7 +250,7 @@ func TestNewFromStringBase64(t *testing.T) { loader, err := NewFromStringBase64(contentWithWhitespace) require.NoError(t, err) - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) require.NoError(t, err) verifyReaderContent(t, reader, script) diff --git a/platform/script/loader/inference_test.go b/platform/script/loader/inference_test.go index e703e23..81fa73c 100644 --- a/platform/script/loader/inference_test.go +++ b/platform/script/loader/inference_test.go @@ -368,7 +368,7 @@ func TestInferFromString(t *testing.T) { assert.IsType(t, tc.expectedType, result) // Verify content - reader, err := result.GetReader() + reader, err := result.GetReader(t.Context()) require.NoError(t, err) defer func() { assert.NoError(t, reader.Close()) @@ -423,7 +423,7 @@ func TestInferLoader_Integration(t *testing.T) { inferredLoader, err := InferLoader(tc.input) require.NoError(t, err) - reader, err := inferredLoader.GetReader() + reader, err := inferredLoader.GetReader(t.Context()) require.NoError(t, err) defer func() { @@ -488,7 +488,7 @@ func TestInferLoader_Integration(t *testing.T) { assert.IsType(t, (*FromDisk)(nil), result) // Verify the content can be read - reader, err := result.GetReader() + reader, err := result.GetReader(t.Context()) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, reader.Close()) @@ -560,7 +560,7 @@ process()`, ) // Verify content can be read correctly - reader, err := result.GetReader() + reader, err := result.GetReader(t.Context()) require.NoError(t, err) defer func() { assert.NoError(t, reader.Close()) diff --git a/platform/script/loader/loader.go b/platform/script/loader/loader.go index 8d7b639..cb3185f 100644 --- a/platform/script/loader/loader.go +++ b/platform/script/loader/loader.go @@ -1,12 +1,17 @@ package loader import ( + "context" "io" "net/url" ) // Loader is an interface used by the engines to load scripts or binaries. +// GetReader accepts a context so loaders performing cancellable I/O (HTTP) +// can honor caller cancellation. In-memory loaders ignore ctx; FromDisk +// accepts ctx for interface conformance but os.Open itself is synchronous +// and does not observe it. type Loader interface { - GetReader() (io.ReadCloser, error) + GetReader(ctx context.Context) (io.ReadCloser, error) GetSourceURL() *url.URL } diff --git a/platform/script/loader/loader_test.go b/platform/script/loader/loader_test.go index 3a08b8e..e423402 100644 --- a/platform/script/loader/loader_test.go +++ b/platform/script/loader/loader_test.go @@ -47,7 +47,7 @@ func verifyLoader(t *testing.T, loader Loader, expectedURLString string) { } // Test getting a reader - reader, err := loader.GetReader() + reader, err := loader.GetReader(t.Context()) if err == nil { // If no error, verify reader works and cleanup require.NotNil(t, reader) @@ -78,7 +78,7 @@ func verifyMultipleReads(t *testing.T, loader Loader, expectedContent string) { t.Helper() // First read - reader1, err := loader.GetReader() + reader1, err := loader.GetReader(t.Context()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, reader1.Close(), "Failed to close first reader") @@ -89,7 +89,7 @@ func verifyMultipleReads(t *testing.T, loader Loader, expectedContent string) { require.Equal(t, expectedContent, string(content1)) // Second read - reader2, err := loader.GetReader() + reader2, err := loader.GetReader(t.Context()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, reader2.Close(), "Failed to close second reader") diff --git a/platform/script/loader/mock_helpers_test.go b/platform/script/loader/mock_helpers_test.go index caa3490..4f3bbac 100644 --- a/platform/script/loader/mock_helpers_test.go +++ b/platform/script/loader/mock_helpers_test.go @@ -2,6 +2,7 @@ package loader import ( "bytes" + "context" "io" "net/url" @@ -21,8 +22,8 @@ func (m *MockLoader) GetSourceURL() *url.URL { return args.Get(0).(*url.URL) } -func (m *MockLoader) GetReader() (io.ReadCloser, error) { - args := m.Called() +func (m *MockLoader) GetReader(ctx context.Context) (io.ReadCloser, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } @@ -37,6 +38,6 @@ func (m *MockLoader) Close() error { // Helper method to easily create a mock with content func NewMockLoaderWithContent(content []byte) *MockLoader { m := new(MockLoader) - m.On("GetReader").Return(io.NopCloser(bytes.NewReader(content)), nil) + m.On("GetReader", mock.Anything).Return(io.NopCloser(bytes.NewReader(content)), nil) return m } diff --git a/platform/script/mocks_test.go b/platform/script/mocks_test.go index e50a9f1..415affb 100644 --- a/platform/script/mocks_test.go +++ b/platform/script/mocks_test.go @@ -1,6 +1,7 @@ package script import ( + "context" "io" engineTypes "github.com/robbyt/go-polyscript/engines/types" @@ -13,8 +14,8 @@ type MockCompiler struct { } // Compile mocks the Compile method of the Compiler interface. -func (m *MockCompiler) Compile(scriptReader io.ReadCloser) (ExecutableContent, error) { - args := m.Called(scriptReader) +func (m *MockCompiler) Compile(ctx context.Context, scriptReader io.ReadCloser) (ExecutableContent, error) { + args := m.Called(ctx, scriptReader) execContent, ok := args.Get(0).(ExecutableContent) if !ok { return nil, args.Error(1)