diff --git a/Cargo.lock b/Cargo.lock index 4178511..43a5a52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -197,6 +197,14 @@ dependencies = [ "wit-component", ] +[[package]] +name = "example-variants" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + [[package]] name = "foldhash" version = "0.2.0" diff --git a/cmd/gravity/src/codegen/exports.rs b/cmd/gravity/src/codegen/exports.rs index a9f42c6..b414168 100644 --- a/cmd/gravity/src/codegen/exports.rs +++ b/cmd/gravity/src/codegen/exports.rs @@ -36,7 +36,7 @@ impl<'a> ExportGenerator<'a> { .params .iter() .map( - |Param { name, ty, .. }| match crate::resolve_type(ty, self.config.resolve) { + |Param { name, ty, .. }| match crate::resolve_param_type(ty, self.config.resolve) { GoType::ValueOrOk(t) => (GoIdentifier::local(name), *t), t => (GoIdentifier::local(name), t), }, diff --git a/cmd/gravity/src/codegen/func.rs b/cmd/gravity/src/codegen/func.rs index c0cac08..35f552c 100644 --- a/cmd/gravity/src/codegen/func.rs +++ b/cmd/gravity/src/codegen/func.rs @@ -10,9 +10,9 @@ use crate::{ go::{ comment, imports::{ - ERRORS_NEW, REFLECT_VALUE_OF, WAZERO_API_DECODE_F32, WAZERO_API_DECODE_F64, - WAZERO_API_DECODE_I32, WAZERO_API_DECODE_U32, WAZERO_API_ENCODE_F32, - WAZERO_API_ENCODE_F64, WAZERO_API_ENCODE_I32, + ERRORS_NEW, WAZERO_API_DECODE_F32, WAZERO_API_DECODE_F64, WAZERO_API_DECODE_I32, + WAZERO_API_DECODE_U32, WAZERO_API_ENCODE_F32, WAZERO_API_ENCODE_F64, + WAZERO_API_ENCODE_I32, }, GoIdentifier, GoResult, GoType, Operand, }, @@ -82,6 +82,17 @@ impl<'a> Func<'a> { ret } + /// The Go expression that resolves to the wasm `api.Module` in the + /// current direction. Exports live on a Go-side instance struct + /// (`i.module`); imports receive the module as a `mod` parameter from + /// wazero's host-function builder. + fn module_handle(&self) -> &'static str { + match self.direction { + Direction::Export => "i.module", + Direction::Import { .. } => "mod", + } + } + pub fn args(&self) -> &[String] { &self.args } @@ -115,6 +126,8 @@ impl Bindgen for Func<'_> { ) { let iter_element = "e"; let iter_base = "base"; + // Hoist to avoid borrow-checker conflict with `quote_in! { self.body => ... }`. + let module_handle = self.module_handle(); match inst { Instruction::GetArg { nth } => { @@ -194,27 +207,27 @@ impl Bindgen for Func<'_> { $['\r'] $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { - $raw, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $raw, $err := $module_handle.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) if $err != nil { var $default $(typ.as_ref()) return $default, $err } } GoResult::Anon(GoType::Error) => { - $raw, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $raw, $err := $module_handle.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) if $err != nil { return $err } } GoResult::Anon(_) => { - $raw, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $raw, $err := $module_handle.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if $err != nil { panic($err) } } GoResult::Empty => { - _, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + _, $err := $module_handle.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if $err != nil { panic($err) @@ -229,7 +242,7 @@ impl Bindgen for Func<'_> { "is done accessing it." ])) defer func() { - if postFn := i.module.ExportedFunction($(quoted(format!("cabi_post_{name}")))); postFn != nil { + if postFn := $module_handle.ExportedFunction($(quoted(format!("cabi_post_{name}")))); postFn != nil { if _, err := postFn.Call(ctx, $raw...); err != nil { $(comment(&[ "If we get an error during cleanup, something really bad is", @@ -262,7 +275,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $value, $ok := i.module.Memory().ReadByte(uint32($operand + $offset)) + $value, $ok := $module_handle.Memory().ReadByte(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -348,7 +361,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $ptr, $ok := i.module.Memory().ReadUint32Le(uint32($operand + $offset)) + $ptr, $ok := $module_handle.Memory().ReadUint32Le(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -381,7 +394,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $len, $ok := i.module.Memory().ReadUint32Le(uint32($operand + $offset)) + $len, $ok := $module_handle.Memory().ReadUint32Le(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -414,7 +427,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $value, $ok := i.module.Memory().ReadUint32Le(uint32($operand + $offset)) + $value, $ok := $module_handle.Memory().ReadUint32Le(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -579,6 +592,11 @@ impl Bindgen for Func<'_> { let value = &format!("value{tmp}"); let err = &format!("err{tmp}"); let ok = &format!("ok{tmp}"); + // `ValueOrError`/`ValueOrOk` are the only two Go shapes that + // come back as multiple return values — everything else (a + // primitive, string, slice, pointer-to-T, interface, or a + // user-defined record/enum/alias) lands in a single + // identifier that subsequent ABI instructions will lower. match self.direction { Direction::Export { .. } => todo!("TODO(#10): handle export direction"), Direction::Import { param_name, .. } => { @@ -586,7 +604,6 @@ impl Bindgen for Func<'_> { $['\r'] $(match returns { GoType::Nothing => $param_name.$ident(ctx, $args), - GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) => $value := $param_name.$ident(ctx, $args), GoType::Error => $err := $param_name.$ident(ctx, $args), GoType::ValueOrError(_) => { $value, $err := $param_name.$ident(ctx, $args) @@ -594,20 +611,13 @@ impl Bindgen for Func<'_> { GoType::ValueOrOk(_) => { $value, $ok := $param_name.$ident(ctx, $args) } - _ => $(comment(&["TODO(#9): handle return type"])) + _ => $value := $param_name.$ident(ctx, $args), }) } } } match returns { GoType::Nothing => (), - GoType::Bool - | GoType::Uint32 - | GoType::Interface - | GoType::UserDefined(_) - | GoType::String => { - results.push(Operand::SingleValue(value.into())); - } GoType::Error => { results.push(Operand::SingleValue(err.into())); } @@ -617,10 +627,14 @@ impl Bindgen for Func<'_> { GoType::ValueOrOk(_) => { results.push(Operand::MultiValue((value.into(), ok.into()))) } - _ => todo!("TODO(#9): handle return type - {returns:?}"), + _ => { + results.push(Operand::SingleValue(value.into())); + } } } Instruction::VariantPayloadName => { + // `VariantLower` and `OptionLower` both bind `variantPayload` + // to the case payload before invoking the per-case block. results.push(Operand::SingleValue("variantPayload".into())); } Instruction::I32Const { val } => results.push(Operand::Literal(val.to_string())), @@ -630,55 +644,26 @@ impl Bindgen for Func<'_> { let tag = &operands[0]; let ptr = &operands[1]; if let Operand::Literal(byte) = tag { - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteByte($ptr+$offset, $byte) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteByte($ptr+$offset, $byte) - } - } + quote_in! { self.body => + $['\r'] + $module_handle.Memory().WriteByte($ptr+$offset, $byte) } } else { let tmp = self.tmp(); let byte = format!("byte{tmp}"); - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - var $(&byte) uint8 - switch $tag { - case 0: - $(&byte) = 0 - case 1: - $(&byte) = 1 - default: - $(comment(["TODO(#8): Return an error if the return type allows it"])) - panic($ERRORS_NEW("invalid int8 value encountered")) - } - i.module.Memory().WriteByte($ptr+$offset, $byte) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - var $(&byte) uint8 - switch $tag { - case 0: - $(&byte) = 0 - case 1: - $(&byte) = 1 - default: - panic($ERRORS_NEW("invalid int8 value encountered")) - } - mod.Memory().WriteByte($ptr+$offset, $byte) - } + quote_in! { self.body => + $['\r'] + var $(&byte) uint8 + switch $tag { + case 0: + $(&byte) = 0 + case 1: + $(&byte) = 1 + default: + $(comment(["TODO(#8): Return an error if the return type allows it"])) + panic($ERRORS_NEW("invalid int8 value encountered")) } + $module_handle.Memory().WriteByte($ptr+$offset, $byte) } } } @@ -687,19 +672,9 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let tag = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, $tag) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, $tag) - } - } + quote_in! { self.body => + $['\r'] + $module_handle.Memory().WriteUint32Le($ptr+$offset, $tag) } } Instruction::LengthStore { offset } => { @@ -707,19 +682,9 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let len = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, uint32($len)) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, uint32($len)) - } - } + quote_in! { self.body => + $['\r'] + $module_handle.Memory().WriteUint32Le($ptr+$offset, uint32($len)) } } Instruction::PointerStore { offset } => { @@ -727,19 +692,9 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let value = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, uint32($value)) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, uint32($value)) - } - } + quote_in! { self.body => + $['\r'] + $module_handle.Memory().WriteUint32Le($ptr+$offset, uint32($value)) } } Instruction::ResultLower { @@ -797,30 +752,25 @@ impl Bindgen for Func<'_> { Instruction::ResultLower { .. } => todo!("implement instruction: {inst:?}"), Instruction::OptionLift { payload, .. } => { let (some, some_results) = self.blocks.pop().unwrap(); - let (none, _) = self.blocks.pop().unwrap(); + let (_none, _) = self.blocks.pop().unwrap(); let some_result = &some_results[0]; let tmp = self.tmp(); let result = &format!("result{tmp}"); - let ok = &format!("ok{tmp}"); - let typ = resolve_type(payload, resolve); + let inner_typ = resolve_type(payload, resolve); let op = &operands[0]; quote_in! { self.body => $['\r'] - var $result $typ - var $ok bool - if $op == 0 { - $none - $ok = false - } else { + var $result *$inner_typ + if $op != 0 { $some - $ok = true - $result = $some_result + someValue$tmp := $some_result + $result = &someValue$tmp } }; - results.push(Operand::MultiValue((result.into(), ok.into()))); + results.push(Operand::SingleValue(result.into())); } Instruction::OptionLower { results: result_types, @@ -831,10 +781,6 @@ impl Bindgen for Func<'_> { let tmp = self.tmp(); - // If there are no result_types, then the payload will be a pointer, - // because that's how we represent optionals in Go. - let is_pointer = result_types.is_empty(); - let mut vars: Tokens = Tokens::new(); for i in 0..result_types.len() { let variant = &format!("variant{tmp}_{i}"); @@ -858,33 +804,17 @@ impl Bindgen for Func<'_> { }; } - let operand = &operands[0]; - match operand { - Operand::Literal(_) => { - panic!("impossible: expected Operand::MultiValue but got Operand::Literal") - } - Operand::SingleValue(value) => { - quote_in! { self.body => - $['\r'] - $vars - if $REFLECT_VALUE_OF($value).IsZero() { - $none_block - } else { - variantPayload := $(if is_pointer => *)$value - $some_block - } - }; - } - Operand::MultiValue((value, ok)) => { - quote_in! { self.body => - $['\r'] - if $ok { - variantPayload := $value - $some_block - } else { - $none_block - } - }; + let Operand::SingleValue(value) = &operands[0] else { + unreachable!("OptionLower expects a single `*T` operand"); + }; + quote_in! { self.body => + $['\r'] + $vars + if $value == nil { + $none_block + } else { + variantPayload := *$value + $some_block } }; } @@ -943,7 +873,7 @@ impl Bindgen for Func<'_> { $['\r'] $vec := $operand $len := uint64(len($vec)) - $result, $err := i.module.ExportedFunction($(quoted(*realloc_name))).Call(ctx, 0, 0, $align, $len * $size) + $result, $err := $module_handle.ExportedFunction($(quoted(*realloc_name))).Call(ctx, 0, 0, $align, $len * $size) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if $err != nil { @@ -1003,9 +933,11 @@ impl Bindgen for Func<'_> { } Instruction::VariantLower { variant, + ty, results: result_types, .. } => { + let name = crate::qualified_type_name(*ty, resolve); let blocks = self .blocks .drain(self.blocks.len() - variant.cases.len()..) @@ -1024,6 +956,22 @@ impl Bindgen for Func<'_> { results.push(Operand::SingleValue(variant_item.into())); } + // Collapse the type-switch when every case is `DirectRecord`: + // the case-struct binder IS the payload, so we can bind + // `variantPayload` once in the switch header instead of + // re-aliasing it per arm. Mixed variants need a separate + // binder so `Wrapped` cases can unwrap via `.Value`. + let all_direct = variant.cases.iter().all(|case| { + matches!( + crate::case_dispatch_kind(case, resolve), + crate::CaseDispatchKind::DirectRecord + ) + }); + let case_binder = if all_direct { + "variantPayload".to_string() + } else { + format!("case{tmp}") + }; let mut cases: Tokens = Tokens::new(); for (case, (block, block_results)) in variant.cases.iter().zip(blocks) { let mut assignments: Tokens = Tokens::new(); @@ -1035,10 +983,28 @@ impl Bindgen for Func<'_> { }; } - let name = GoIdentifier::public(case.name.clone()); + let case_type = GoIdentifier::public(crate::case_dispatch_name( + &name, case, resolve, + )); + let payload_intro = if all_direct { + quote!() + } else { + match crate::case_dispatch_kind(case, resolve) { + crate::CaseDispatchKind::DirectRecord => { + quote!(variantPayload := $(&case_binder)$['\r']) + } + crate::CaseDispatchKind::Wrapped if case.ty.is_some() => { + quote!(variantPayload := $(&case_binder).Value$['\r']) + } + crate::CaseDispatchKind::Wrapped => { + quote!(_ = $(&case_binder)$['\r']) + } + } + }; quote_in! { cases => $['\r'] - case $name: + case $case_type: + $payload_intro $block $assignments } @@ -1046,7 +1012,7 @@ impl Bindgen for Func<'_> { quote_in! { self.body => $['\r'] - switch variantPayload := $value.(type) { + switch $(&case_binder) := $value.(type) { $cases default: $(match &self.result { @@ -1106,7 +1072,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $value, $ok := $module_handle.Memory().ReadUint64Le(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -1139,7 +1105,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $value, $ok := $module_handle.Memory().ReadUint64Le(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -1172,7 +1138,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $value, $ok := $module_handle.Memory().ReadUint64Le(uint32($operand + $offset)) $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { if !$ok { @@ -1202,19 +1168,9 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let tag = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint64Le($ptr+$offset, $tag) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint64Le($ptr+$offset, $tag) - } - } + quote_in! { self.body => + $['\r'] + $module_handle.Memory().WriteUint64Le($ptr+$offset, $tag) } } Instruction::F64Store { offset } => { @@ -1222,19 +1178,9 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let tag = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint64Le($ptr+$offset, $tag) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint64Le($ptr+$offset, $tag) - } - } + quote_in! { self.body => + $['\r'] + $module_handle.Memory().WriteUint64Le($ptr+$offset, $tag) } } Instruction::I32FromChar => todo!("implement instruction: {inst:?}"), @@ -1396,10 +1342,108 @@ impl Bindgen for Func<'_> { Instruction::TupleLift { .. } => todo!("implement instruction: {inst:?}"), Instruction::FlagsLower { .. } => todo!("implement instruction: {inst:?}"), Instruction::FlagsLift { .. } => todo!("implement instruction: {inst:?}"), - Instruction::VariantLift { .. } => { - todo!("implement instruction: {inst:?}") + Instruction::VariantLift { variant, ty, .. } => { + let name = crate::qualified_type_name(*ty, resolve); + let blocks = self + .blocks + .drain(self.blocks.len() - variant.cases.len()..) + .collect::>(); + let discriminant = &operands[0]; + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let variant_type = GoType::UserDefined(name.clone()); + + let mut cases: Tokens = Tokens::new(); + for (i, (case, (block, block_results))) in + variant.cases.iter().zip(blocks).enumerate() + { + let case_type = + GoIdentifier::public(crate::case_dispatch_name(&name, case, resolve)); + let payload = block_results.first(); + let construction = match crate::case_dispatch_kind(case, resolve) { + crate::CaseDispatchKind::DirectRecord => { + let payload = payload.expect("DirectRecord case has a payload"); + quote!($payload) + } + crate::CaseDispatchKind::Wrapped => match payload { + None => quote!($(&case_type){}), + Some(payload) => quote!($(&case_type){Value: $payload}), + }, + }; + quote_in! { cases => + $['\r'] + case $i: + $block + $value = $construction + }; + } + + let err_msg = format!("\"invalid {name} discriminant\""); + quote_in! { self.body => + $['\r'] + var $value $variant_type + switch $discriminant { + $cases + default: + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + var default0 $(typ.as_ref()) + return default0, $ERRORS_NEW($(&err_msg)) + } + GoResult::Anon(GoType::Error) => { + return $ERRORS_NEW($(&err_msg)) + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + panic($ERRORS_NEW($(&err_msg))) + } + }) + } + }; + + results.push(Operand::SingleValue(value.to_string())); + } + Instruction::EnumLift { enum_, ty, .. } => { + let name = crate::qualified_type_name(*ty, resolve); + let discriminant = &operands[0]; + let tmp = self.tmp(); + let enum_value = &format!("enum{tmp}"); + let go_type = GoType::UserDefined(name.clone()); + + let mut cases: Tokens = Tokens::new(); + for (i, case) in enum_.cases.iter().enumerate() { + let case_name = GoIdentifier::public(case.name.clone()); + quote_in! { cases => + $['\r'] + case $i: + $enum_value = $case_name + }; + } + + quote_in! { self.body => + $['\r'] + var $enum_value $go_type + switch $discriminant { + $cases + default: + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + var default0 $(typ.as_ref()) + return default0, $ERRORS_NEW($(format!("\"invalid {name} discriminant\""))) + } + GoResult::Anon(GoType::Error) => { + return $ERRORS_NEW($(format!("\"invalid {name} discriminant\""))) + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + panic($ERRORS_NEW($(format!("\"invalid {name} discriminant\"")))) + } + }) + } + }; + + results.push(Operand::SingleValue(enum_value.to_string())); } - Instruction::EnumLift { .. } => todo!("implement instruction: {inst:?}"), Instruction::Malloc { .. } => todo!("implement instruction: {inst:?}"), Instruction::HandleLower { .. } | Instruction::HandleLift { .. } => { todo!("implement resources: {inst:?}") diff --git a/cmd/gravity/src/codegen/imports.rs b/cmd/gravity/src/codegen/imports.rs index c07cf37..15c589f 100644 --- a/cmd/gravity/src/codegen/imports.rs +++ b/cmd/gravity/src/codegen/imports.rs @@ -4,7 +4,7 @@ use genco::prelude::*; use wit_bindgen_core::{ abi::{AbiVariant, LiftLower}, wit_parser::{ - Function, InterfaceId, Param, Resolve, SizeAlign, Type, TypeDefKind, TypeId, World, + Case, Function, InterfaceId, Param, Resolve, SizeAlign, Type, TypeDefKind, TypeId, World, WorldItem, }, }; @@ -13,15 +13,15 @@ use crate::{ codegen::{ func::Func, ir::{ - AnalyzedFunction, AnalyzedImports, AnalyzedInterface, AnalyzedType, InterfaceMethod, - Parameter, TypeDefinition, WitReturn, + AnalyzedFunction, AnalyzedImports, AnalyzedInterface, AnalyzedType, CaseDispatch, + InterfaceMethod, Parameter, TypeDefinition, VariantCase, WitReturn, }, }, go::{ imports::{CONTEXT_CONTEXT, WAZERO_API_MODULE}, GoIdentifier, GoResult, GoType, }, - resolve_type, resolve_wasm_type, + resolve_param_type, resolve_type, resolve_wasm_type, }; /// Analyzer for imports - only does analysis, no code generation @@ -120,7 +120,7 @@ impl<'a> ImportAnalyzer<'a> { .iter() .map(|Param { name, ty, .. }| Parameter { name: GoIdentifier::private(name), - go_type: resolve_type(ty, self.resolve), + go_type: resolve_param_type(ty, self.resolve), wit_type: *ty, }) .collect(); @@ -141,25 +141,49 @@ impl<'a> ImportAnalyzer<'a> { fn analyze_type(&self, type_id: TypeId) -> Option { let type_def = &self.resolve.types[type_id]; - let type_name = type_def.name.as_ref().expect("type missing name"); - - let go_type_name = GoIdentifier::public(type_name); - let definition = self.analyze_type_definition(&type_def.kind); + let qualified = crate::qualified_type_name(type_id, self.resolve); + let go_type_name = GoIdentifier::public(&qualified); + // Variants live here (not in `analyze_type_definition`) because + // their case wrapper names need the qualified variant name. + let definition = match &type_def.kind { + TypeDefKind::Variant(variant) => Some(TypeDefinition::Variant { + cases: variant + .cases + .iter() + .map(|case| self.analyze_variant_case(&qualified, case)) + .collect(), + }), + kind => self.analyze_type_definition(kind), + }; definition.map(|definition| AnalyzedType { - name: type_name.clone(), + name: qualified, go_type_name, definition, }) } - /// Analyze a type definition and return an intermediate representation ready for - /// codegen. - /// - /// Returns `None` if the kind is just a `TypeDefKind::Type(Type::Id)`, because this - /// is probably a reference to an imported type that we have already analyzed. - /// - /// TODO: we should probably instead resolve and return type and dedup elsewhere. + fn analyze_variant_case(&self, variant_name: &str, case: &Case) -> VariantCase { + let payload = case.ty.as_ref().map(|t| resolve_type(t, self.resolve)); + let dispatch = match crate::case_dispatch_kind(case, self.resolve) { + crate::CaseDispatchKind::DirectRecord => CaseDispatch::DirectRecord { + record_type: payload + .clone() + .expect("DirectRecord case has a payload"), + }, + crate::CaseDispatchKind::Wrapped => CaseDispatch::Wrapped { + wrapper_name: GoIdentifier::public(format!("{variant_name}-{}", case.name)), + }, + }; + VariantCase { + name: case.name.clone(), + payload, + dispatch, + } + } + + /// Analyze a type definition. Returns `None` for `Type::Id` aliases + /// that just re-export an already-analyzed type. fn analyze_type_definition(&self, kind: &TypeDefKind) -> Option { Some(match kind { TypeDefKind::Record(record) => TypeDefinition::Record { @@ -177,18 +201,9 @@ impl<'a> ImportAnalyzer<'a> { TypeDefKind::Enum(enum_def) => TypeDefinition::Enum { cases: enum_def.cases.iter().map(|c| c.name.clone()).collect(), }, - TypeDefKind::Variant(variant) => TypeDefinition::Variant { - cases: variant - .cases - .iter() - .map(|case| { - ( - case.name.clone(), - case.ty.as_ref().map(|t| resolve_type(t, self.resolve)), - ) - }) - .collect(), - }, + TypeDefKind::Variant(_) => unreachable!( + "Variant analysis is handled in `analyze_type` where the qualified name is in scope" + ), TypeDefKind::Type(Type::Id(_)) => { // TODO(#4): Only skip this if we have already generated the type return None; @@ -234,7 +249,7 @@ impl<'a> ImportAnalyzer<'a> { .iter() .map(|Param { name, ty, .. }| Parameter { name: GoIdentifier::private(name), - go_type: resolve_type(ty, self.resolve), + go_type: resolve_param_type(ty, self.resolve), wit_type: *ty, }) .collect(); @@ -398,10 +413,33 @@ impl<'a> ImportCodeGenerator<'a> { // Primitive type: $(typ.name) } } - TypeDefinition::Variant { .. } => { + TypeDefinition::Variant { cases } => { + let variant_interface = &typ.go_type_name; + let marker_method = + &GoIdentifier::private(format!("is-{}", &typ.name)); + let case_definitions = cases.iter().map(|case| match &case.dispatch { + CaseDispatch::DirectRecord { record_type } => quote! { + $['\n'] + func ($record_type) $marker_method() {} + }, + CaseDispatch::Wrapped { wrapper_name } => { + let payload_field = case.payload.as_ref().map(|p| quote!(Value $p)); + quote! { + $['\n'] + type $wrapper_name struct { + $(if let Some(field) = payload_field => $field) + } + $['\n'] + func ($wrapper_name) $marker_method() {} + } + } + }); quote_in! { *tokens => $['\n'] - // Variant type: $(typ.name) (TODO: implement) + type $variant_interface interface { + $marker_method() + } + $(for def in case_definitions => $def) } } } @@ -1199,6 +1237,9 @@ mod tests { assert_eq!(interface.types.len(), 1); let analyzed_type = &interface.types[0]; + // Interface-scoped types are qualified only when their bare name + // would collide with another concrete type in the same world. The + // test world's `foo` is unique, so it stays flat. assert_eq!(analyzed_type.name, "foo"); println!("Analyzed type definition: {:?}", analyzed_type.definition); diff --git a/cmd/gravity/src/codegen/ir.rs b/cmd/gravity/src/codegen/ir.rs index 5a79de3..a141fad 100644 --- a/cmd/gravity/src/codegen/ir.rs +++ b/cmd/gravity/src/codegen/ir.rs @@ -113,10 +113,8 @@ pub struct AnalyzedType { pub enum TypeDefinition { /// A struct-like type with named fields Record { fields: Vec<(GoIdentifier, GoType)> }, - /// A union-like type with multiple cases, each optionally carrying data - Variant { - cases: Vec<(String, Option)>, - }, + /// A union-like type with multiple cases, each optionally carrying data. + Variant { cases: Vec }, /// A simple enumeration with named constants Enum { cases: Vec }, /// A type alias that wraps another type @@ -125,6 +123,29 @@ pub enum TypeDefinition { Primitive, } +#[derive(Debug, Clone)] +pub struct VariantCase { + pub name: String, + /// `None` for unit cases. + pub payload: Option, + pub dispatch: CaseDispatch, +} + +/// How a variant case is represented in Go. See `crate::CaseDispatchKind` +/// for the underlying decision; this enum carries the resolved Go names so +/// they aren't re-derived at emission time. +#[derive(Debug, Clone)] +pub enum CaseDispatch { + /// WIT shorthand `case(case)`: the payload record itself implements + /// the variant's marker interface, so callers construct + /// `MyRecord{...}` directly. + DirectRecord { record_type: GoType }, + /// Dedicated `{VariantName}{CaseName}` wrapper struct with an optional + /// `Value` field. Callers construct `Wrapper{Value: payload}` (or + /// `Wrapper{}` for unit cases) and read the payload via `.Value`. + Wrapped { wrapper_name: GoIdentifier }, +} + /// An analyzed WIT function. #[derive(Debug, Clone)] pub struct AnalyzedFunction { diff --git a/cmd/gravity/src/go/type.rs b/cmd/gravity/src/go/type.rs index 47cfaba..24e47e3 100644 --- a/cmd/gravity/src/go/type.rs +++ b/cmd/gravity/src/go/type.rs @@ -46,6 +46,11 @@ pub enum GoType { Slice(Box), /// Multi-return type (for functions returning arbitrary multiple values) // MultiReturn(Vec), + /// Pointer to another type. Used as the canonical Go representation of + /// `option` so the same lowering composes in every position (params, + /// return values, record fields, list elements). `nil` is `none`, + /// `&value` is `some`. + Pointer(Box), /// User-defined type (records, enums, type aliases) UserDefined(String), /// Represents no value/void @@ -93,6 +98,11 @@ impl GoType { // Complex types need cleanup if their inner types do GoType::ValueOrOk(inner) => inner.needs_cleanup(), + // Pointer (representing option) needs cleanup only when its + // inner type does. Strings and slices behind a pointer still own + // memory the guest allocated. + GoType::Pointer(inner) => inner.needs_cleanup(), + // The inner type of `Err` is always a String so it requires cleanup // TODO(#91): Store the error type to check both inner types. GoType::ValueOrError(_) => true, @@ -162,10 +172,10 @@ impl FormatInto for &GoType { // GoType::MultiReturn(typs) => { // tokens.append(quote!($(for typ in typs join (, ) => $typ))) // } - // GoType::Pointer(typ) => { - // tokens.append(static_literal("*")); - // typ.as_ref().format_into(tokens); - // } + GoType::Pointer(typ) => { + tokens.append(static_literal("*")); + typ.as_ref().format_into(tokens); + } GoType::UserDefined(name) => { let id = GoIdentifier::public(name); id.format_into(tokens) diff --git a/cmd/gravity/src/lib.rs b/cmd/gravity/src/lib.rs index 16600b8..0b4654c 100644 --- a/cmd/gravity/src/lib.rs +++ b/cmd/gravity/src/lib.rs @@ -4,12 +4,93 @@ pub mod go; use crate::go::GoType; use wit_bindgen_core::{ abi::WasmType, - wit_parser::{Resolve, Result_, Type, TypeDef, TypeDefKind}, + dealias, + wit_parser::{Case, Resolve, Result_, Type, TypeDef, TypeDefKind, TypeId, TypeOwner}, }; // Temporary re-export while we migrate. pub use codegen::Func; +/// How a single variant case is represented in Go. +pub enum CaseDispatchKind { + /// The case payload's named record IS the dispatch type — the record + /// implements the variant's marker interface directly. Constructed as + /// `MyRecord{...}`. + DirectRecord, + /// A dedicated `{variant_name}-{case_name}` wrapper struct holds the + /// optional payload in a `Value` field. Constructed as + /// `Wrapper{Value: payload}` or `Wrapper{}` for unit cases. + Wrapped, +} + +/// Detect the WIT shorthand `case-name(case-name)` where the payload is a +/// named record sharing the case's name — the historical arcjet shape +/// (`allow-email-validation-config(allow-email-validation-config)`). We let +/// the record implement the marker interface directly so existing call +/// sites that construct the record value as the variant keep working. +pub fn case_dispatch_kind(case: &Case, resolve: &Resolve) -> CaseDispatchKind { + if let Some(Type::Id(payload_id)) = &case.ty { + let payload_def = &resolve.types[*payload_id]; + if matches!(payload_def.kind, TypeDefKind::Record(_)) + && payload_def.name.as_deref() == Some(case.name.as_str()) + { + return CaseDispatchKind::DirectRecord; + } + } + CaseDispatchKind::Wrapped +} + +/// Kebab-case Go name for the type a variant case dispatches against in a +/// type-switch. +pub fn case_dispatch_name(variant_name: &str, case: &Case, resolve: &Resolve) -> String { + match case_dispatch_kind(case, resolve) { + CaseDispatchKind::DirectRecord => match case.ty { + Some(Type::Id(payload_id)) => qualified_type_name(payload_id, resolve), + _ => unreachable!("DirectRecord requires a Type::Id payload"), + }, + CaseDispatchKind::Wrapped => format!("{variant_name}-{}", case.name), + } +} + +/// Returns a globally-unique kebab-case name suitable for deriving a Go +/// identifier from a WIT type. WIT lets two interfaces declare types of +/// the same name (e.g. both `email-validator-overrides` and `verify-bot` +/// declare an `enum validator-response`); we qualify only the colliding +/// names with their owning interface so stable single-instance names like +/// `algorithm-result` stay flat. The result is fed to +/// `GoIdentifier::public`, so it must remain in kebab-case. +pub fn qualified_type_name(type_id: TypeId, resolve: &Resolve) -> String { + let canonical = dealias(resolve, type_id); + let type_def = &resolve.types[canonical]; + let name = type_def + .name + .as_ref() + .expect("expected named type for qualified_type_name"); + + // Skip `Type` aliases when looking for collisions: they re-export an + // existing type rather than introducing a new one. + let collides = resolve.types.iter().any(|(other_id, other_def)| { + other_id != canonical + && other_def.name.as_deref() == Some(name.as_str()) + && !matches!(other_def.kind, TypeDefKind::Type(_)) + }); + + if !collides { + return name.clone(); + } + + match type_def.owner { + TypeOwner::Interface(id) => { + let interface_name = resolve.interfaces[id] + .name + .as_ref() + .expect("interface missing name"); + format!("{interface_name}-{name}") + } + TypeOwner::World(_) | TypeOwner::None => name.clone(), + } +} + /// Resolves a Wasm type to a Go type. pub fn resolve_wasm_type(typ: &WasmType) -> GoType { match typ { @@ -55,26 +136,24 @@ pub fn resolve_type(typ: &Type, resolve: &Resolve) -> GoType { // Complex types. Type::Id(id) => { - let TypeDef { name, kind, .. } = resolve + let TypeDef { kind, .. } = resolve .types .get(*id) .expect("failed to find type definition"); match kind { - TypeDefKind::Record(_) => { - GoType::UserDefined(name.clone().expect("expected record to have a name")) - } + TypeDefKind::Record(_) => GoType::UserDefined(qualified_type_name(*id, resolve)), TypeDefKind::Resource => todo!("TODO(#5): implement resources"), TypeDefKind::Handle(_) => todo!("TODO(#5): implement resources"), TypeDefKind::Flags(_) => todo!("TODO(#4): implement flag conversion"), TypeDefKind::Tuple(_) => todo!("TODO(#4): implement tuple conversion"), - // Variants are handled as an empty interfaces in type signatures; however, that - // means they require runtime type reflection - TypeDefKind::Variant(_) => GoType::Interface, - TypeDefKind::Enum(_) => { - GoType::UserDefined(name.clone().expect("expected enum to have a name")) - } + TypeDefKind::Variant(_) => GoType::UserDefined(qualified_type_name(*id, resolve)), + TypeDefKind::Enum(_) => GoType::UserDefined(qualified_type_name(*id, resolve)), + // `option` is `*T`: `nil` is `none`, `&v` is `some`. A + // single pointer composes in every position (param, return, + // record field, list element); the prior `(T, bool)` + // comma-ok shape didn't. TypeDefKind::Option(value) => { - GoType::ValueOrOk(Box::new(resolve_type(value, resolve))) + GoType::Pointer(Box::new(resolve_type(value, resolve))) } // Various results, including specialised ones. @@ -108,9 +187,7 @@ pub fn resolve_type(typ: &Type, resolve: &Resolve) -> GoType { TypeDefKind::List(inner) => GoType::Slice(Box::new(resolve_type(inner, resolve))), TypeDefKind::Future(_) => todo!("TODO(#4): implement future conversion"), TypeDefKind::Stream(_) => todo!("TODO(#4): implement stream conversion"), - TypeDefKind::Type(_) => { - GoType::UserDefined(name.clone().expect("expected type alias to have a name")) - } + TypeDefKind::Type(_) => GoType::UserDefined(qualified_type_name(*id, resolve)), TypeDefKind::FixedLengthList(_, _) => { todo!("TODO(#4): implement fixed length list conversion") } @@ -120,3 +197,23 @@ pub fn resolve_type(typ: &Type, resolve: &Resolve) -> GoType { } } } + +/// Like [`resolve_type`], but downgrades a top-level Variant to +/// `interface{}` so existing call sites can keep passing the variant +/// payload through `any`-typed plumbing (rule config returns, generic +/// dispatch layers). The marker interface and per-case structs are still +/// generated, and the type-switch in `VariantLower` still dispatches on +/// the concrete case types — callers who want compile-time exhaustiveness +/// just declare their value as the marker interface explicitly. +/// +/// Variants nested inside records, lists, or returns stay typed so +/// generated record fields remain strongly typed. +pub fn resolve_param_type(typ: &Type, resolve: &Resolve) -> GoType { + if let Type::Id(id) = typ { + let def = &resolve.types[dealias(resolve, *id)]; + if matches!(def.kind, TypeDefKind::Variant(_)) { + return GoType::Interface; + } + } + resolve_type(typ, resolve) +} diff --git a/cmd/gravity/tests/cmd/basic.stdout b/cmd/gravity/tests/cmd/basic.stdout index 5fd7413..8724bd8 100644 --- a/cmd/gravity/tests/cmd/basic.stdout +++ b/cmd/gravity/tests/cmd/basic.stdout @@ -6,7 +6,6 @@ import "context" import "errors" import "github.com/tetratelabs/wazero" import "github.com/tetratelabs/wazero/api" -import "reflect" import _ "embed" @@ -301,16 +300,16 @@ func (i *BasicInstance) Primitive( func (i *BasicInstance) OptionalPrimitive( ctx context.Context, - b bool, -) (bool, bool) { + b *bool, +) *bool { arg0 := b var variant1_0 uint32 var variant1_1 uint32 - if reflect.ValueOf(arg0).IsZero() { + if arg0 == nil { variant1_0 = 0 variant1_1 = 0 } else { - variantPayload := arg0 + variantPayload := *arg0 var value0 uint32 if variantPayload { value0 = 1 @@ -332,21 +331,18 @@ func (i *BasicInstance) OptionalPrimitive( if !ok3 { panic(errors.New("failed to read byte from memory")) } - var result6 bool - var ok6 bool - if value3 == 0 { - ok6 = false - } else { + var result6 *bool + if value3 != 0 { value4, ok4 := i.module.Memory().ReadByte(uint32(results2 + 1)) // The return type doesn't contain an error so we panic if one is encountered if !ok4 { panic(errors.New("failed to read byte from memory")) } value5 := value4 != 0 - ok6 = true - result6 = value5 + someValue6 := value5 + result6 = &someValue6 } - return result6, ok6 + return result6 } func (i *BasicInstance) ResultPrimitive( @@ -415,18 +411,18 @@ func (i *BasicInstance) ResultPrimitive( func (i *BasicInstance) OptionalString( ctx context.Context, - s string, -) (string, bool) { + s *string, +) *string { arg0 := s var variant1_0 uint32 var variant1_1 uint64 var variant1_2 uint64 - if reflect.ValueOf(arg0).IsZero() { + if arg0 == nil { variant1_0 = 0 variant1_1 = 0 variant1_2 = 0 } else { - variantPayload := arg0 + variantPayload := *arg0 memory0 := i.module.Memory() realloc0 := i.module.ExportedFunction("cabi_realloc") ptr0, len0, err0 := writeString(ctx, variantPayload, memory0, realloc0) @@ -464,11 +460,8 @@ func (i *BasicInstance) OptionalString( if !ok3 { panic(errors.New("failed to read byte from memory")) } - var result7 string - var ok7 bool - if value3 == 0 { - ok7 = false - } else { + var result7 *string + if value3 != 0 { ptr4, ok4 := i.module.Memory().ReadUint32Le(uint32(results2 + 4)) // The return type doesn't contain an error so we panic if one is encountered if !ok4 { @@ -485,9 +478,9 @@ func (i *BasicInstance) OptionalString( panic(errors.New("failed to read bytes from memory")) } str6 := string(buf6) - ok7 = true - result7 = str6 + someValue7 := str6 + result7 = &someValue7 } - return result7, ok7 + return result7 } diff --git a/cmd/gravity/tests/cmd/regressions.stdout b/cmd/gravity/tests/cmd/regressions.stdout index ba00add..de00d9f 100644 --- a/cmd/gravity/tests/cmd/regressions.stdout +++ b/cmd/gravity/tests/cmd/regressions.stdout @@ -50,6 +50,55 @@ type IRegressionsPinger interface { ) bool } +type IRegressionsEmailChecker interface { + IsAllowed( + ctx context.Context, + email string, + ) EmailCheckerValidatorResponse +} + +type EmailCheckerValidatorResponse interface { + isEmailCheckerValidatorResponse() +} + +type emailCheckerValidatorResponse int + +func (emailCheckerValidatorResponse) isEmailCheckerValidatorResponse() {} + +const ( + Yes emailCheckerValidatorResponse = iota + No emailCheckerValidatorResponse = iota + Maybe emailCheckerValidatorResponse = iota +) + +type IRegressionsBotVerifier interface { + Verify( + ctx context.Context, + botId string, + ) BotVerifierValidatorResponse +} + +type BotVerifierValidatorResponse interface { + isBotVerifierValidatorResponse() +} + +type botVerifierValidatorResponse int + +func (botVerifierValidatorResponse) isBotVerifierValidatorResponse() {} + +const ( + Verified botVerifierValidatorResponse = iota + Spoofed botVerifierValidatorResponse = iota + Unverifiable botVerifierValidatorResponse = iota +) + +type IRegressionsIpSource interface { + Lookup( + ctx context.Context, + ip string, + ) *string +} + type RegressionsFactory struct { runtime wazero.Runtime module wazero.CompiledModule @@ -60,9 +109,44 @@ func NewRegressionsFactory( checker IRegressionsChecker, processor IRegressionsProcessor, pinger IRegressionsPinger, + emailChecker IRegressionsEmailChecker, + botVerifier IRegressionsBotVerifier, + ipSource IRegressionsIpSource, ) (*RegressionsFactory, error) { wazeroRuntime := wazero.NewRuntime(ctx) + _, err4 := wazeroRuntime.NewHostModuleBuilder("gravity:regressions/bot-verifier"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + ) uint32{ + buf0, ok0 := mod.Memory().Read(arg0, arg1) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + value1 := botVerifier.Verify(ctx, str0) + var enum2 uint32 + switch value1 { + case Verified: + enum2 = 0 + case Spoofed: + enum2 = 1 + case Unverifiable: + enum2 = 2 + default: + panic(errors.New("invalid enum type provided")) + } + return enum2 + }). + Export("verify"). + Instantiate(ctx) + if err4 != nil { + return nil, err4 + } _, err0 := wazeroRuntime.NewHostModuleBuilder("gravity:regressions/checker"). NewFunctionBuilder(). WithFunc(func( @@ -117,6 +201,73 @@ func NewRegressionsFactory( if err0 != nil { return nil, err0 } + _, err3 := wazeroRuntime.NewHostModuleBuilder("gravity:regressions/email-checker"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + ) uint32{ + buf0, ok0 := mod.Memory().Read(arg0, arg1) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + value1 := emailChecker.IsAllowed(ctx, str0) + var enum2 uint32 + switch value1 { + case Yes: + enum2 = 0 + case No: + enum2 = 1 + case Maybe: + enum2 = 2 + default: + panic(errors.New("invalid enum type provided")) + } + return enum2 + }). + Export("is-allowed"). + Instantiate(ctx) + if err3 != nil { + return nil, err3 + } + _, err5 := wazeroRuntime.NewHostModuleBuilder("gravity:regressions/ip-source"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + arg2 uint32, + ) { + buf0, ok0 := mod.Memory().Read(arg0, arg1) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + value1 := ipSource.Lookup(ctx, str0) + if value1 == nil { + mod.Memory().WriteByte(arg2+0, 0) + } else { + variantPayload := *value1 + mod.Memory().WriteByte(arg2+0, 1) + memory2 := mod.Memory() + realloc2 := mod.ExportedFunction("cabi_realloc") + ptr2, len2, err2 := writeString(ctx, variantPayload, memory2, realloc2) + if err2 != nil { + panic(err2) + } + mod.Memory().WriteUint32Le(arg2+8, uint32(len2)) + mod.Memory().WriteUint32Le(arg2+4, uint32(ptr2)) + } + }). + Export("lookup"). + Instantiate(ctx) + if err5 != nil { + return nil, err5 + } _, err2 := wazeroRuntime.NewHostModuleBuilder("gravity:regressions/pinger"). NewFunctionBuilder(). WithFunc(func( @@ -292,3 +443,101 @@ func (i *RegressionsInstance) RunPing( return value1 } +func (i *RegressionsInstance) CheckEmailAllowed( + ctx context.Context, + email string, +) uint32 { + arg0 := email + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, arg0, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + raw1, err1 := i.module.ExportedFunction("check-email-allowed").Call(ctx, uint64(ptr0), uint64(len0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + results1 := raw1[0] + result2 := uint32(results1) + return result2 +} + +func (i *RegressionsInstance) CheckBotVerified( + ctx context.Context, + botId string, +) uint32 { + arg0 := botId + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, arg0, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + raw1, err1 := i.module.ExportedFunction("check-bot-verified").Call(ctx, uint64(ptr0), uint64(len0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + results1 := raw1[0] + result2 := uint32(results1) + return result2 +} + +func (i *RegressionsInstance) RunIpLookup( + ctx context.Context, + ip string, +) string { + arg0 := ip + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, arg0, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + raw1, err1 := i.module.ExportedFunction("run-ip-lookup").Call(ctx, uint64(ptr0), uint64(len0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_run-ip-lookup"); postFn != nil { + if _, err := postFn.Call(ctx, raw1...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results1 := raw1[0] + ptr2, ok2 := i.module.Memory().ReadUint32Le(uint32(results1 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok2 { + panic(errors.New("failed to read pointer from memory")) + } + len3, ok3 := i.module.Memory().ReadUint32Le(uint32(results1 + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok3 { + panic(errors.New("failed to read length from memory")) + } + buf4, ok4 := i.module.Memory().Read(ptr2, len3) + // The return type doesn't contain an error so we panic if one is encountered + if !ok4 { + panic(errors.New("failed to read bytes from memory")) + } + str4 := string(buf4) + return str4 +} + diff --git a/cmd/gravity/tests/cmd/variants.stderr b/cmd/gravity/tests/cmd/variants.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/variants.stdout b/cmd/gravity/tests/cmd/variants.stdout new file mode 100644 index 0000000..3141048 --- /dev/null +++ b/cmd/gravity/tests/cmd/variants.stdout @@ -0,0 +1,651 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package variants + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" + +import _ "embed" + +//go:embed variants.wasm +var wasmFileVariants []byte + +type Entity interface { + isEntity() +} + +type EntityEmail struct {} + +func (EntityEmail) isEntity() {} + +type EntityPhoneNumber struct {} + +func (EntityPhoneNumber) isEntity() {} + +type EntityIpAddress struct {} + +func (EntityIpAddress) isEntity() {} + +type EntityCreditCardNumber struct {} + +func (EntityCreditCardNumber) isEntity() {} + +type EntityCustom struct { + Value string +} + +func (EntityCustom) isEntity() {} + +type Allow struct { + Entities []Entity + ContextWindowSize *uint32 +} + +type Deny struct { + Entities []Entity +} + +type Config interface { + isConfig() +} + +func (Allow) isConfig() {} + +func (Deny) isConfig() {} + +type Entities interface { + isEntities() +} + +type EntitiesAllowAll struct { + Value []Entity +} + +func (EntitiesAllowAll) isEntities() {} + +type EntitiesDenyAll struct { + Value []Entity +} + +func (EntitiesDenyAll) isEntities() {} + +type Detected struct { + Kind Entity + Start uint32 + End uint32 +} + +type VariantsFactory struct { + runtime wazero.Runtime + module wazero.CompiledModule +} + +func NewVariantsFactory( + ctx context.Context, +) (*VariantsFactory, error) { + wazeroRuntime := wazero.NewRuntime(ctx) + + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileVariants) + if err != nil { + return nil, err + } + return &VariantsFactory{ + runtime: wazeroRuntime, + module: module, + }, nil +} + +func (f *VariantsFactory) Instantiate(ctx context.Context) (*VariantsInstance, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &VariantsInstance{module}, nil + } +} + +func (f *VariantsFactory) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type VariantsInstance struct { + module api.Module +} + +func (i *VariantsInstance) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling conventions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, errors.New("failed to write string to wasm memory") + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *VariantsInstance) Classify( + ctx context.Context, + input string, +) Entity { + arg0 := input + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, arg0, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + raw1, err1 := i.module.ExportedFunction("classify").Call(ctx, uint64(ptr0), uint64(len0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_classify"); postFn != nil { + if _, err := postFn.Call(ctx, raw1...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results1 := raw1[0] + value2, ok2 := i.module.Memory().ReadByte(uint32(results1 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok2 { + panic(errors.New("failed to read byte from memory")) + } + var value6 Entity + switch value2 { + case 0: + value6 = EntityEmail{} + case 1: + value6 = EntityPhoneNumber{} + case 2: + value6 = EntityIpAddress{} + case 3: + value6 = EntityCreditCardNumber{} + case 4: + ptr3, ok3 := i.module.Memory().ReadUint32Le(uint32(results1 + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok3 { + panic(errors.New("failed to read pointer from memory")) + } + len4, ok4 := i.module.Memory().ReadUint32Le(uint32(results1 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok4 { + panic(errors.New("failed to read length from memory")) + } + buf5, ok5 := i.module.Memory().Read(ptr3, len4) + // The return type doesn't contain an error so we panic if one is encountered + if !ok5 { + panic(errors.New("failed to read bytes from memory")) + } + str5 := string(buf5) + value6 = EntityCustom{Value: str5} + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid entity discriminant")) + } + return value6 +} + +func (i *VariantsInstance) TagAll( + ctx context.Context, + inputs []string, +) []Detected { + arg0 := inputs + vec1 := arg0 + len1 := uint64(len(vec1)) + result1, err1 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len1 * 8) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + ptr1 := result1[0] + for idx := uint64(0); idx < len1; idx++ { + e := vec1[idx] + base := uint32(ptr1 + uint64(idx) * uint64(8)) + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, e, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + i.module.Memory().WriteUint32Le(base+4, uint32(len0)) + i.module.Memory().WriteUint32Le(base+0, uint32(ptr0)) + } + raw2, err2 := i.module.ExportedFunction("tag-all").Call(ctx, uint64(ptr1), uint64(len1)) + // The return type doesn't contain an error so we panic if one is encountered + if err2 != nil { + panic(err2) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_tag-all"); postFn != nil { + if _, err := postFn.Call(ctx, raw2...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results2 := raw2[0] + ptr3, ok3 := i.module.Memory().ReadUint32Le(uint32(results2 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok3 { + panic(errors.New("failed to read pointer from memory")) + } + len4, ok4 := i.module.Memory().ReadUint32Le(uint32(results2 + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok4 { + panic(errors.New("failed to read length from memory")) + } + base15 := ptr3 + len15 := len4 + result15 := make([]Detected, len15) + for idx15 := uint32(0); idx15 < len15; idx15++ { + base := base15 + idx15 * 20 + value5, ok5 := i.module.Memory().ReadByte(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok5 { + panic(errors.New("failed to read byte from memory")) + } + var value9 Entity + switch value5 { + case 0: + value9 = EntityEmail{} + case 1: + value9 = EntityPhoneNumber{} + case 2: + value9 = EntityIpAddress{} + case 3: + value9 = EntityCreditCardNumber{} + case 4: + ptr6, ok6 := i.module.Memory().ReadUint32Le(uint32(base + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok6 { + panic(errors.New("failed to read pointer from memory")) + } + len7, ok7 := i.module.Memory().ReadUint32Le(uint32(base + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok7 { + panic(errors.New("failed to read length from memory")) + } + buf8, ok8 := i.module.Memory().Read(ptr6, len7) + // The return type doesn't contain an error so we panic if one is encountered + if !ok8 { + panic(errors.New("failed to read bytes from memory")) + } + str8 := string(buf8) + value9 = EntityCustom{Value: str8} + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid entity discriminant")) + } + value10, ok10 := i.module.Memory().ReadUint32Le(uint32(base + 12)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok10 { + panic(errors.New("failed to read i32 from memory")) + } + result11 := uint32(value10) + value12, ok12 := i.module.Memory().ReadUint32Le(uint32(base + 16)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok12 { + panic(errors.New("failed to read i32 from memory")) + } + result13 := uint32(value12) + value14 := Detected{ + Kind: value9, + Start: result11, + End: result13, + } + result15[idx15] = value14 + } + return result15 +} + +func (i *VariantsInstance) Choose( + ctx context.Context, + input interface{}, +) string { + arg0 := input + var variant10_0 uint32 + var variant10_1 uint64 + var variant10_2 uint64 + var variant10_3 uint32 + var variant10_4 uint32 + switch variantPayload := arg0.(type) { + case Allow: + entities0 := variantPayload.Entities + contextWindowSize0 := variantPayload.ContextWindowSize + vec3 := entities0 + len3 := uint64(len(vec3)) + result3, err3 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len3 * 12) + // The return type doesn't contain an error so we panic if one is encountered + if err3 != nil { + panic(err3) + } + ptr3 := result3[0] + for idx := uint64(0); idx < len3; idx++ { + e := vec3[idx] + base := uint32(ptr3 + uint64(idx) * uint64(12)) + switch case2 := e.(type) { + case EntityEmail: + _ = case2 + i.module.Memory().WriteByte(base+0, 0) + case EntityPhoneNumber: + _ = case2 + i.module.Memory().WriteByte(base+0, 1) + case EntityIpAddress: + _ = case2 + i.module.Memory().WriteByte(base+0, 2) + case EntityCreditCardNumber: + _ = case2 + i.module.Memory().WriteByte(base+0, 3) + case EntityCustom: + variantPayload := case2.Value + i.module.Memory().WriteByte(base+0, 4) + memory1 := i.module.Memory() + realloc1 := i.module.ExportedFunction("cabi_realloc") + ptr1, len1, err1 := writeString(ctx, variantPayload, memory1, realloc1) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + i.module.Memory().WriteUint32Le(base+8, uint32(len1)) + i.module.Memory().WriteUint32Le(base+4, uint32(ptr1)) + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid variant type provided")) + } + } + var variant5_0 uint32 + var variant5_1 uint32 + if contextWindowSize0 == nil { + variant5_0 = 0 + variant5_1 = 0 + } else { + variantPayload := *contextWindowSize0 + result4 := uint32(variantPayload) + variant5_0 = 1 + variant5_1 = result4 + } + variant10_0 = 0 + variant10_1 = ptr3 + variant10_2 = len3 + variant10_3 = variant5_0 + variant10_4 = variant5_1 + case Deny: + entities6 := variantPayload.Entities + vec9 := entities6 + len9 := uint64(len(vec9)) + result9, err9 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len9 * 12) + // The return type doesn't contain an error so we panic if one is encountered + if err9 != nil { + panic(err9) + } + ptr9 := result9[0] + for idx := uint64(0); idx < len9; idx++ { + e := vec9[idx] + base := uint32(ptr9 + uint64(idx) * uint64(12)) + switch case8 := e.(type) { + case EntityEmail: + _ = case8 + i.module.Memory().WriteByte(base+0, 0) + case EntityPhoneNumber: + _ = case8 + i.module.Memory().WriteByte(base+0, 1) + case EntityIpAddress: + _ = case8 + i.module.Memory().WriteByte(base+0, 2) + case EntityCreditCardNumber: + _ = case8 + i.module.Memory().WriteByte(base+0, 3) + case EntityCustom: + variantPayload := case8.Value + i.module.Memory().WriteByte(base+0, 4) + memory7 := i.module.Memory() + realloc7 := i.module.ExportedFunction("cabi_realloc") + ptr7, len7, err7 := writeString(ctx, variantPayload, memory7, realloc7) + // The return type doesn't contain an error so we panic if one is encountered + if err7 != nil { + panic(err7) + } + i.module.Memory().WriteUint32Le(base+8, uint32(len7)) + i.module.Memory().WriteUint32Le(base+4, uint32(ptr7)) + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid variant type provided")) + } + } + variant10_0 = 1 + variant10_1 = ptr9 + variant10_2 = len9 + variant10_3 = 0 + variant10_4 = 0 + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid variant type provided")) + } + raw11, err11 := i.module.ExportedFunction("choose").Call(ctx, uint64(variant10_0), uint64(variant10_1), uint64(variant10_2), uint64(variant10_3), uint64(variant10_4)) + // The return type doesn't contain an error so we panic if one is encountered + if err11 != nil { + panic(err11) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_choose"); postFn != nil { + if _, err := postFn.Call(ctx, raw11...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results11 := raw11[0] + ptr12, ok12 := i.module.Memory().ReadUint32Le(uint32(results11 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok12 { + panic(errors.New("failed to read pointer from memory")) + } + len13, ok13 := i.module.Memory().ReadUint32Le(uint32(results11 + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok13 { + panic(errors.New("failed to read length from memory")) + } + buf14, ok14 := i.module.Memory().Read(ptr12, len13) + // The return type doesn't contain an error so we panic if one is encountered + if !ok14 { + panic(errors.New("failed to read bytes from memory")) + } + str14 := string(buf14) + return str14 +} + +func (i *VariantsInstance) ChooseMany( + ctx context.Context, + input interface{}, +) string { + arg0 := input + var variant6_0 uint32 + var variant6_1 uint64 + var variant6_2 uint64 + switch case6 := arg0.(type) { + case EntitiesAllowAll: + variantPayload := case6.Value + vec2 := variantPayload + len2 := uint64(len(vec2)) + result2, err2 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len2 * 12) + // The return type doesn't contain an error so we panic if one is encountered + if err2 != nil { + panic(err2) + } + ptr2 := result2[0] + for idx := uint64(0); idx < len2; idx++ { + e := vec2[idx] + base := uint32(ptr2 + uint64(idx) * uint64(12)) + switch case1 := e.(type) { + case EntityEmail: + _ = case1 + i.module.Memory().WriteByte(base+0, 0) + case EntityPhoneNumber: + _ = case1 + i.module.Memory().WriteByte(base+0, 1) + case EntityIpAddress: + _ = case1 + i.module.Memory().WriteByte(base+0, 2) + case EntityCreditCardNumber: + _ = case1 + i.module.Memory().WriteByte(base+0, 3) + case EntityCustom: + variantPayload := case1.Value + i.module.Memory().WriteByte(base+0, 4) + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, variantPayload, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + i.module.Memory().WriteUint32Le(base+8, uint32(len0)) + i.module.Memory().WriteUint32Le(base+4, uint32(ptr0)) + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid variant type provided")) + } + } + variant6_0 = 0 + variant6_1 = ptr2 + variant6_2 = len2 + case EntitiesDenyAll: + variantPayload := case6.Value + vec5 := variantPayload + len5 := uint64(len(vec5)) + result5, err5 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len5 * 12) + // The return type doesn't contain an error so we panic if one is encountered + if err5 != nil { + panic(err5) + } + ptr5 := result5[0] + for idx := uint64(0); idx < len5; idx++ { + e := vec5[idx] + base := uint32(ptr5 + uint64(idx) * uint64(12)) + switch case4 := e.(type) { + case EntityEmail: + _ = case4 + i.module.Memory().WriteByte(base+0, 0) + case EntityPhoneNumber: + _ = case4 + i.module.Memory().WriteByte(base+0, 1) + case EntityIpAddress: + _ = case4 + i.module.Memory().WriteByte(base+0, 2) + case EntityCreditCardNumber: + _ = case4 + i.module.Memory().WriteByte(base+0, 3) + case EntityCustom: + variantPayload := case4.Value + i.module.Memory().WriteByte(base+0, 4) + memory3 := i.module.Memory() + realloc3 := i.module.ExportedFunction("cabi_realloc") + ptr3, len3, err3 := writeString(ctx, variantPayload, memory3, realloc3) + // The return type doesn't contain an error so we panic if one is encountered + if err3 != nil { + panic(err3) + } + i.module.Memory().WriteUint32Le(base+8, uint32(len3)) + i.module.Memory().WriteUint32Le(base+4, uint32(ptr3)) + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid variant type provided")) + } + } + variant6_0 = 1 + variant6_1 = ptr5 + variant6_2 = len5 + default: + // The return type doesn't contain an error so we panic if one is encountered + panic(errors.New("invalid variant type provided")) + } + raw7, err7 := i.module.ExportedFunction("choose-many").Call(ctx, uint64(variant6_0), uint64(variant6_1), uint64(variant6_2)) + // The return type doesn't contain an error so we panic if one is encountered + if err7 != nil { + panic(err7) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_choose-many"); postFn != nil { + if _, err := postFn.Call(ctx, raw7...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results7 := raw7[0] + ptr8, ok8 := i.module.Memory().ReadUint32Le(uint32(results7 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok8 { + panic(errors.New("failed to read pointer from memory")) + } + len9, ok9 := i.module.Memory().ReadUint32Le(uint32(results7 + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok9 { + panic(errors.New("failed to read length from memory")) + } + buf10, ok10 := i.module.Memory().Read(ptr8, len9) + // The return type doesn't contain an error so we panic if one is encountered + if !ok10 { + panic(errors.New("failed to read bytes from memory")) + } + str10 := string(buf10) + return str10 +} + diff --git a/cmd/gravity/tests/cmd/variants.toml b/cmd/gravity/tests/cmd/variants.toml new file mode 100644 index 0000000..55ee673 --- /dev/null +++ b/cmd/gravity/tests/cmd/variants.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world variants ../../target/wasm32-unknown-unknown/release/example_variants.wasm" diff --git a/examples/basic/basic_test.go b/examples/basic/basic_test.go index 806558d..4dd2e1f 100644 --- a/examples/basic/basic_test.go +++ b/examples/basic/basic_test.go @@ -78,14 +78,15 @@ func TestNoOptionalPrimitiveCleanup(t *testing.T) { } defer ins.Close(t.Context()) - actual, ok := ins.OptionalPrimitive(t.Context(), true) - if !ok { - t.Fatal(err) + in := true + actual := ins.OptionalPrimitive(t.Context(), &in) + if actual == nil { + t.Fatal("expected non-nil option result") } const expected = true - if actual != expected { - t.Errorf("expected: %t, but got: %t", expected, actual) + if *actual != expected { + t.Errorf("expected: %t, but got: %t", expected, *actual) } } diff --git a/examples/generate.go b/examples/generate.go index 302d0f3..909c31a 100644 --- a/examples/generate.go +++ b/examples/generate.go @@ -5,9 +5,11 @@ package examples //go:generate cargo build -p example-iface-method-returns-string --target wasm32-unknown-unknown --release //go:generate cargo build -p example-instructions --target wasm32-unknown-unknown --release //go:generate cargo build -p example-regressions --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-variants --target wasm32-unknown-unknown --release //go:generate cargo run --bin gravity -- --world basic --output ./basic/basic.go ../target/wasm32-unknown-unknown/release/example_basic.wasm //go:generate cargo run --bin gravity -- --world records --output ./records/records.go ../target/wasm32-unknown-unknown/release/example_records.wasm //go:generate cargo run --bin gravity -- --world example --output ./iface-method-returns-string/example.go ../target/wasm32-unknown-unknown/release/example_iface_method_returns_string.wasm //go:generate cargo run --bin gravity -- --world instructions --output ./instructions/bindings.go ../target/wasm32-unknown-unknown/release/example_instructions.wasm //go:generate cargo run --bin gravity -- --world regressions --output ./regressions/regressions.go ../target/wasm32-unknown-unknown/release/example_regressions.wasm +//go:generate cargo run --bin gravity -- --world variants --output ./variants/variants.go ../target/wasm32-unknown-unknown/release/example_variants.wasm diff --git a/examples/regressions/regressions_test.go b/examples/regressions/regressions_test.go index 830c2d7..b6b660e 100644 --- a/examples/regressions/regressions_test.go +++ b/examples/regressions/regressions_test.go @@ -58,9 +58,55 @@ func (Pinger) Ping(_ context.Context) bool { return true } +// EmailChecker, BotVerifier, IpSource — regression 4 (cross-interface +// enum collision) and regression 5 (callback returning option). +type EmailChecker struct{} + +func (EmailChecker) IsAllowed(_ context.Context, email string) EmailCheckerValidatorResponse { + switch email { + case "allow@example.com": + return Yes + case "block@example.com": + return No + default: + return Maybe + } +} + +type BotVerifier struct{} + +func (BotVerifier) Verify(_ context.Context, botID string) BotVerifierValidatorResponse { + switch botID { + case "verified-bot": + return Verified + case "spoofed-bot": + return Spoofed + default: + return Unverifiable + } +} + +type IpSource struct{} + +func (IpSource) Lookup(_ context.Context, ip string) *string { + if ip == "127.0.0.1" { + s := "localhost" + return &s + } + return nil +} + func newInstance(t *testing.T) *RegressionsInstance { t.Helper() - fac, err := NewRegressionsFactory(t.Context(), Checker{}, Processor{}, Pinger{}) + fac, err := NewRegressionsFactory( + t.Context(), + Checker{}, + Processor{}, + Pinger{}, + EmailChecker{}, + BotVerifier{}, + IpSource{}, + ) if err != nil { t.Fatal(err) } @@ -159,7 +205,58 @@ func TestRunPing(t *testing.T) { } } -// TODO: When gravity supports generating Go variant type definitions, add E2E -// tests for export functions that accept variant parameters (e.g. a variant -// with a u32 or u64 payload). These would exercise the VariantLower codepath -// end-to-end through wazero. +// TestCrossInterfaceEnumCollision covers regression 4. Both +// `email-checker` and `bot-verifier` define `enum validator-response` +// inside the same world; without interface-scoped qualification the +// generated Go contained two `type ValidatorResponse interface { ... }` +// declarations and refused to compile. We verify (a) both host method +// signatures use qualified Go type names that exist alongside each other +// and (b) the wasm guest can dispatch on each independently. +func TestCrossInterfaceEnumCollision(t *testing.T) { + ins := newInstance(t) + + emailTests := []struct { + input string + want uint32 + }{ + {"allow@example.com", 0}, + {"block@example.com", 1}, + {"other", 2}, + } + for _, tt := range emailTests { + if got := ins.CheckEmailAllowed(t.Context(), tt.input); got != tt.want { + t.Errorf("CheckEmailAllowed(%q) = %d, want %d", tt.input, got, tt.want) + } + } + + botTests := []struct { + input string + want uint32 + }{ + {"verified-bot", 0}, + {"spoofed-bot", 1}, + {"other", 2}, + } + for _, tt := range botTests { + if got := ins.CheckBotVerified(t.Context(), tt.input); got != tt.want { + t.Errorf("CheckBotVerified(%q) = %d, want %d", tt.input, got, tt.want) + } + } +} + +// TestImportCallbackOptionString covers regression 5. The `ip-source` +// import returns `option`. Lowering it into wasm memory must run +// against the IMPORT-side module handle (`mod.Memory()` / +// `mod.ExportedFunction("cabi_realloc")`) — gravity previously +// hard-coded the export-side `i.module.*` handle in list and option +// lowering, producing `undefined: i` from the generated host wrapper. +func TestImportCallbackOptionString(t *testing.T) { + ins := newInstance(t) + + if got := ins.RunIpLookup(t.Context(), "127.0.0.1"); got != "localhost" { + t.Errorf("RunIpLookup(\"127.0.0.1\") = %q, want \"localhost\"", got) + } + if got := ins.RunIpLookup(t.Context(), "0.0.0.0"); got != "absent" { + t.Errorf("RunIpLookup(\"0.0.0.0\") = %q, want \"absent\"", got) + } +} diff --git a/examples/regressions/src/lib.rs b/examples/regressions/src/lib.rs index f5bead4..d12ff1e 100644 --- a/examples/regressions/src/lib.rs +++ b/examples/regressions/src/lib.rs @@ -1,4 +1,6 @@ -use gravity::regressions::{checker, pinger, processor}; +use gravity::regressions::{ + bot_verifier, checker, email_checker, ip_source, pinger, processor, +}; wit_bindgen::generate!({ world: "regressions", @@ -28,4 +30,24 @@ impl Guest for RegressionsWorld { fn run_ping() -> bool { pinger::ping() } + + fn check_email_allowed(email: String) -> u32 { + match email_checker::is_allowed(&email) { + email_checker::ValidatorResponse::Yes => 0, + email_checker::ValidatorResponse::No => 1, + email_checker::ValidatorResponse::Maybe => 2, + } + } + + fn check_bot_verified(bot_id: String) -> u32 { + match bot_verifier::verify(&bot_id) { + bot_verifier::ValidatorResponse::Verified => 0, + bot_verifier::ValidatorResponse::Spoofed => 1, + bot_verifier::ValidatorResponse::Unverifiable => 2, + } + } + + fn run_ip_lookup(ip: String) -> String { + ip_source::lookup(&ip).unwrap_or_else(|| "absent".to_string()) + } } diff --git a/examples/regressions/wit/regressions.wit b/examples/regressions/wit/regressions.wit index 0401d24..e40ddb0 100644 --- a/examples/regressions/wit/regressions.wit +++ b/examples/regressions/wit/regressions.wit @@ -1,7 +1,6 @@ package gravity:regressions; -// Regression 1: Import functions returning bool and enum types. -// Previously caused: todo!("implement handling of wasm signatures with results") +// Regression 1: import functions returning bool and enum types. interface checker { is-enabled: func(key: string) -> bool; @@ -13,33 +12,68 @@ interface checker { get-status: func(key: string) -> status; } -// Regression 2: Import functions with u32 parameters. -// Previously caused: uint32/uint64 type mismatch in generated Go code -// due to api.EncodeU32/api.DecodeU32 being used instead of uint32() casts. +// Regression 2: import functions with u32 parameters. interface processor { double: func(value: u32) -> u32; } -// Regression 3: Import functions with zero WIT parameters. -// Previously caused: Go syntax error from trailing comma in host function -// signature — func(ctx context.Context, mod api.Module, ,) — because the -// template unconditionally emitted a comma separator between the fixed -// params and the (empty) WIT params. +// Regression 3: import functions with zero WIT parameters. interface pinger { ping: func() -> bool; } +// Regression 4: two interfaces in the same world both define an enum named +// `validator-response`. Without interface-scoped name qualification, the +// generated Go redeclares the same type twice. +interface email-checker { + // Case names overlap intentionally with the bot-verifier enum below, but + // `maybe` is used in place of `unknown` so the case CONSTANTS don't + // collide with `checker.status.unknown` (constant qualification is out + // of scope for this regression). + enum validator-response { + yes, + no, + maybe, + } + is-allowed: func(email: string) -> validator-response; +} + +interface bot-verifier { + enum validator-response { + verified, + spoofed, + unverifiable, + } + verify: func(bot-id: string) -> validator-response; +} + +// Regression 5: import callback returning `option`. Exercises +// import-side lowering of an option (the option payload must be written +// through the host-side `mod.Memory()` / `mod.ExportedFunction(...)`). +interface ip-source { + lookup: func(ip: string) -> option; +} + world regressions { import checker; import processor; import pinger; + import email-checker; + import bot-verifier; + import ip-source; export check-enabled: func(key: string) -> bool; export check-status: func(key: string) -> u32; export double-value: func(value: u32) -> u32; export run-ping: func() -> bool; - // TODO: When variant type definition generation is supported, add variant - // exports here (e.g. variant with u32/u64 payloads) and corresponding E2E - // tests in regressions_test.go. + // Both calls funnel back into u32 discriminants so the test can assert + // against the host implementations without depending on the generated + // Go enum type names. + export check-email-allowed: func(email: string) -> u32; + export check-bot-verified: func(bot-id: string) -> u32; + + // Round-trip the option callback. Returns "absent" when the + // host returned `none`. + export run-ip-lookup: func(ip: string) -> string; } diff --git a/examples/variants/Cargo.toml b/examples/variants/Cargo.toml new file mode 100644 index 0000000..27e7fb6 --- /dev/null +++ b/examples/variants/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-variants" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.56.0" +wit-component = "=0.246.2" diff --git a/examples/variants/src/lib.rs b/examples/variants/src/lib.rs new file mode 100644 index 0000000..c9c0051 --- /dev/null +++ b/examples/variants/src/lib.rs @@ -0,0 +1,49 @@ +wit_bindgen::generate!({ + world: "variants", +}); + +struct VariantsWorld; + +export!(VariantsWorld); + +impl Guest for VariantsWorld { + fn classify(input: String) -> Entity { + match input.as_str() { + "email" => Entity::Email, + "phone" => Entity::PhoneNumber, + "ip" => Entity::IpAddress, + "cc" => Entity::CreditCardNumber, + other => Entity::Custom(other.to_string()), + } + } + + fn tag_all(inputs: Vec) -> Vec { + inputs + .into_iter() + .enumerate() + .map(|(i, input)| Detected { + kind: Self::classify(input), + start: i as u32, + end: (i + 1) as u32, + }) + .collect() + } + + fn choose(input: Config) -> String { + match input { + Config::Allow(allow) => format!( + "allow:{}:ctx={:?}", + allow.entities.len(), + allow.context_window_size + ), + Config::Deny(deny) => format!("deny:{}", deny.entities.len()), + } + } + + fn choose_many(input: Entities) -> String { + match input { + Entities::AllowAll(list) => format!("allow-all:{}", list.len()), + Entities::DenyAll(list) => format!("deny-all:{}", list.len()), + } + } +} diff --git a/examples/variants/variants_test.go b/examples/variants/variants_test.go new file mode 100644 index 0000000..c825400 --- /dev/null +++ b/examples/variants/variants_test.go @@ -0,0 +1,137 @@ +package variants + +import ( + "testing" +) + +// TestClassify_UnitCase exercises lifting a unit-case variant returned +// from the guest. The wrapper struct must be zero-sized and the type +// switch on the Go side must accept it as the variant interface. +func TestClassify_UnitCase(t *testing.T) { + ins := newInstance(t) + + got := ins.Classify(t.Context(), "email") + if _, ok := got.(EntityEmail); !ok { + t.Fatalf("Classify(\"email\") = %T, want EntityEmail", got) + } +} + +// TestClassify_PayloadCase exercises lifting a payload-bearing case +// where the payload is a primitive — the wrapper carries it in `Value`. +// Regression: the original sensitive-info-entity `custom(string)` shape. +func TestClassify_PayloadCase(t *testing.T) { + ins := newInstance(t) + + got := ins.Classify(t.Context(), "anything-else") + custom, ok := got.(EntityCustom) + if !ok { + t.Fatalf("Classify(\"anything-else\") = %T, want EntityCustom", got) + } + if custom.Value != "anything-else" { + t.Errorf("EntityCustom.Value = %q, want \"anything-else\"", custom.Value) + } +} + +// TestTagAll exercises: +// - returning a list of records from the guest (list lift of record) +// - records that contain a variant field (variant lift inside record) +// - the variant having both unit and payload cases in the same list +func TestTagAll(t *testing.T) { + ins := newInstance(t) + + got := ins.TagAll(t.Context(), []string{"email", "custom-thing", "ip"}) + if len(got) != 3 { + t.Fatalf("TagAll len = %d, want 3", len(got)) + } + if _, ok := got[0].Kind.(EntityEmail); !ok { + t.Errorf("got[0].Kind = %T, want EntityEmail", got[0].Kind) + } + if got[0].Start != 0 || got[0].End != 1 { + t.Errorf("got[0] indices = %d/%d, want 0/1", got[0].Start, got[0].End) + } + custom, ok := got[1].Kind.(EntityCustom) + if !ok { + t.Errorf("got[1].Kind = %T, want EntityCustom", got[1].Kind) + } else if custom.Value != "custom-thing" { + t.Errorf("got[1] Custom.Value = %q, want \"custom-thing\"", custom.Value) + } + if _, ok := got[2].Kind.(EntityIpAddress); !ok { + t.Errorf("got[2].Kind = %T, want EntityIpAddress", got[2].Kind) + } +} + +// TestChoose_DirectRecordDispatch exercises the WIT shorthand +// `case(case)` — when a variant case's only payload is a named record +// sharing its name, gravity should let the record satisfy the variant +// marker directly. Callers MUST be able to pass `Allow{...}` (not +// `ConfigAllow{Value: Allow{...}}`). +func TestChoose_DirectRecordDispatch(t *testing.T) { + ins := newInstance(t) + + window := uint32(3) + got := ins.Choose(t.Context(), Allow{ + Entities: []Entity{ + EntityEmail{}, + EntityCustom{Value: "tagged"}, + }, + ContextWindowSize: &window, + }) + want := "allow:2:ctx=Some(3)" + if got != want { + t.Errorf("Choose(Allow{...}) = %q, want %q", got, want) + } + + got = ins.Choose(t.Context(), Deny{ + Entities: []Entity{EntityIpAddress{}}, + }) + want = "deny:1" + if got != want { + t.Errorf("Choose(Deny{...}) = %q, want %q", got, want) + } +} + +// TestChooseMany exercises a variant whose payload is `list` +// (e.g. `allow-all(list)`). Gravity must wrap the list in a +// `EntitiesAllowAll{Value: []Entity{...}}` struct. +func TestChooseMany(t *testing.T) { + ins := newInstance(t) + + got := ins.ChooseMany(t.Context(), EntitiesAllowAll{ + Value: []Entity{ + EntityEmail{}, + EntityPhoneNumber{}, + EntityCustom{Value: "x"}, + }, + }) + if got != "allow-all:3" { + t.Errorf("ChooseMany allow-all = %q, want \"allow-all:3\"", got) + } + + got = ins.ChooseMany(t.Context(), EntitiesDenyAll{ + Value: []Entity{EntityIpAddress{}, EntityCreditCardNumber{}}, + }) + if got != "deny-all:2" { + t.Errorf("ChooseMany deny-all = %q, want \"deny-all:2\"", got) + } +} + +func newInstance(t *testing.T) *VariantsInstance { + t.Helper() + fac, err := NewVariantsFactory(t.Context()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { fac.Close(t.Context()) }) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ins.Close(t.Context()); err != nil { + t.Error(err) + } + }) + + return ins +} diff --git a/examples/variants/wit/variants.wit b/examples/variants/wit/variants.wit new file mode 100644 index 0000000..684f84b --- /dev/null +++ b/examples/variants/wit/variants.wit @@ -0,0 +1,73 @@ +package gravity:variants; + +/// Regression coverage for variant codegen. +/// +/// The two variants below are intentionally shaped after the patterns that +/// initially blew up gravity when generating bindings for arcjet's +/// `bindings_js_req` world: +/// +/// * `entity` mixes unit cases with a payload-carrying case (`custom`), +/// and the payload is a primitive (`string`) rather than a named +/// record. Each case needs its own `{Variant}{Case}` wrapper struct. +/// This is the `sensitive-info-entity` shape. +/// +/// * `config` has two payload-only cases whose names equal the records +/// they carry (`allow(allow)`, `deny(deny)`). We rely on the WIT +/// shorthand to let the record satisfy the marker interface directly, +/// so existing arcjet callers that pass `Allow{...}` (instead of +/// `ConfigAllow{Value: Allow{...}}`) keep working. +/// +/// * `entities` wraps a list of variants inside a variant, exercising +/// variant lift/lower on the list element side. +/// +/// We also nest `entity` inside `detected` to exercise variant-in-record +/// codegen and `list` exercises list-of-record lifting. +world variants { + variant entity { + email, + phone-number, + ip-address, + credit-card-number, + custom(string), + } + + record allow { + entities: list, + /// `option` field — exercises option-in-record codegen. + context-window-size: option, + } + + record deny { + entities: list, + } + + /// Shorthand `case(case)` — should map to direct-record dispatch so + /// callers can pass `Allow{...}` / `Deny{...}` directly without + /// wrapping in `ConfigAllow{Value: ...}`. + variant config { + allow(allow), + deny(deny), + } + + /// Variant containing a `list` payload. + variant entities { + allow-all(list), + deny-all(list), + } + + record detected { + /// Variant field nested in a record (exercises Lift through record + /// fields). + kind: entity, + start: u32, + end: u32, + } + + /// Exports both lift (constructing a variant inside the guest and + /// returning it via a record) and lower (accepting a variant from the + /// host, dispatching on it, mutating, returning). + export classify: func(input: string) -> entity; + export tag-all: func(inputs: list) -> list; + export choose: func(input: config) -> string; + export choose-many: func(input: entities) -> string; +}