Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 123 additions & 56 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/wasm/text/Ctx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import hkmc2.utils.*

import document.*
import document.Document
import semantics.*
import semantics.*, Elaborator.State
import text.Param as WasmParam
import Instructions.*

import scala.annotation.nowarn
import scala.collection.mutable.{ArrayBuffer as ArrayBuf, Map as MutMap}

/** A Wasm function and its associated information.
Expand All @@ -30,14 +31,17 @@ import scala.collection.mutable.{ArrayBuffer as ArrayBuf, Map as MutMap}
* [[Seq]] of local variables (excluding parameters) and their names.
* @param body
* The expression of the function body.
* @param exports
* Optional export name for the function.
*/
class FuncInfo(
val id: Opt[SymIdx],
val id: SymIdx,
val typeIdx: TypeIdx,
params: Seq[Local -> Str],
nResults: Int,
locals: Seq[Local -> Str],
val body: Expr,
val `export`: Opt[Str],
) extends ToWat:

/** @param sym
Expand All @@ -60,13 +64,33 @@ class FuncInfo(
nResults: Int,
locals: Seq[Local -> Str],
body: Expr,
) = this(
sym.optionIf(_.nameIsMeaningful).map(sym => SymIdx(sym.nme)),
)(using Raise, Scope) = this(
SymIdx(sym.optionIf(_.nameIsMeaningful).fold(summon[Scope].allocateName(sym))(_.nme)),
typeIdx,
params,
nResults,
locals,
body,
sym.optionIf(_.nameIsMeaningful).map(_.nme),
)

@deprecated("Consider providing a symbolic identifier by using `Scope.allocateName` with a `TempSymbol`.")
def this(
id: Opt[SymIdx],
typeIdx: TypeIdx,
params: Seq[Local -> Str],
nResults: Int,
locals: Seq[Local -> Str],
body: Expr,
`export`: Opt[Str],
)(using Raise, Scope, State) = this(
id.getOrElse(SymIdx(summon[Scope].allocateName(TempSymbol(N, "")))),
typeIdx,
params,
nResults,
locals,
body,
`export`,
)

/** Returns the type of this function as a [[SignatureType]]. */
Expand All @@ -76,16 +100,16 @@ class FuncInfo(
)

def toWat: Document =
doc"""(func ${id.fold(doc"")(_.toWat)} (type ${typeIdx.toWat})${
doc"""(func ${id.toWat} (type ${typeIdx.toWat})${
getSignatureType.toWat.surroundUnlessEmpty(doc" ")
} #{ ${
locals.map: p =>
doc"(local $$${p._2} ${RefType.anyref.toWat})"
.mkDocument(doc" # ").surroundUnlessEmpty(doc" # ")
} # ${body.toWat} #} )${
id.fold(doc""): id =>
doc""" # (export "${id.id}" (func ${id.toWat})) # (elem declare func ${id.toWat})"""
}"""
`export`.fold(doc""): e =>
doc""" # (export "${e}" (func ${id.toWat}))"""
} # (elem declare func ${id.toWat})"""
end FuncInfo

/** A Wasm global and its associated information.
Expand Down Expand Up @@ -121,29 +145,33 @@ end GlobalInfo
* Symbolic identifier for the function, or `N` if the function is anonymous.
* @param compType
* The composite type this type definition represents.
* @param objectTag
* An optional object tag number associated with this type.
*/
class TypeInfo(val id: Opt[SymIdx], val compType: CompType) extends ToWat:
class TypeInfo(val id: SymIdx, val compType: CompType, val objectTag: Opt[Int]) extends ToWat:

/** @param sym
* The source [[BlockMemberSymbol]] which this type is generated from.
* @param compType
* The composite type this type definition represents.
*/
def this(sym: BlockMemberSymbol, compType: CompType) = this(
sym.optionIf(_.nameIsMeaningful).map(sym => SymIdx(sym.nme)),
def this(sym: BlockMemberSymbol, compType: CompType, objectTag: Opt[Int])(using Raise, Scope) = this(
SymIdx(sym.optionIf(_.nameIsMeaningful).fold(summon[Scope].allocateName(sym))(_.nme)),
compType,
objectTag,
)

private def idDoc: Document = id.fold(doc"")(_.toWat)
def this(id: Opt[SymIdx], compType: CompType)(using Raise, Scope, State) =
this(id.getOrElse(SymIdx(summon[Scope].allocateName(TempSymbol(N, "")))), compType, N)

def toWat: Document = compType match
case struct: StructType if struct.isSubtype =>
val parentsDoc = struct.parents.optionIf(_.nonEmpty).fold(doc""): parents =>
parents.map(_.toWat).mkDocument(doc" ")
val structDoc = struct.copy(isSubtype = false).toWat
doc"(type${idDoc.surroundUnlessEmpty(doc" ")} (sub${parentsDoc.surroundUnlessEmpty(doc" ")} ${structDoc}))"
doc"(type ${id.toWat} (sub${parentsDoc.surroundUnlessEmpty(doc" ")} ${structDoc}))"
case _ =>
doc"(type${idDoc.surroundUnlessEmpty(doc" ")} ${compType.toWat})"
doc"(type ${id.toWat} ${compType.toWat})"
end TypeInfo

/** A WebAssembly exception tag declaration.
Expand Down Expand Up @@ -241,22 +269,25 @@ end Ctx
*/
class Ctx(
types: ArrayBuf[TypeInfo],
namedTypes: MutMap[BlockMemberSymbol, NumIdx],
namedTypes: MutMap[BlockMemberSymbol, Int],
memoryImports: ArrayBuf[MemoryImport],
functionImports: ArrayBuf[FuncImport],
dataSegments: ArrayBuf[DataSegment],
funcs: ArrayBuf[FuncInfo],
funcInfosByIndex: MutMap[NumIdx, FuncInfo],
funcInfosByIndex: MutMap[Int, FuncInfo],
globals: ArrayBuf[GlobalInfo],
namedFuncs: MutMap[Symbol, NumIdx],
namedFuncs: MutMap[Symbol, Int],
tags: ArrayBuf[TagInfo],
namedGlobals: MutMap[Symbol, NumIdx],
var locals: Ls[MutMap[Local, NumIdx]],
namedGlobals: MutMap[Symbol, Int],
var locals: Ls[MutMap[Local, Int]],
private var startFunc: Opt[FuncIdx],
) extends ToWat:

import Ctx.prettyString

/** Counter for generating object tags. */
private var objectTagNum = 0

private val wasmIntrinsicFuncs: MutMap[Str, FuncIdx] = MutMap.empty
private val wasmIntrinsicTypes: MutMap[WasmIntrinsicType, TypeIdx] = MutMap.empty
private val wasmIntrinsicTags: MutMap[Str, TagIdx] = MutMap.empty
Expand All @@ -281,37 +312,54 @@ class Ctx(
labelTargets.collectFirst:
case (sym, target) if sym eq label => target

/** Returns a new number to be used as an object tag. */
def getFreshObjectTag(): Int =
val tag = objectTagNum
objectTagNum += 1
tag

/** Adds a type into this context. */
def addType(sym: Opt[BlockMemberSymbol], typeInfo: TypeInfo): TypeIdx =
val numIdx = NumIdx(types.size)
val numIdx = types.size
types += typeInfo
sym.foreach:
namedTypes(_) = numIdx
TypeIdx(typeInfo.id.getOrElse(numIdx))
TypeIdx(typeInfo.id)

@deprecated("Use the overload without `resolveSymIdx` instead.")
def getType(typeref: TypeIdx | BlockMemberSymbol, resolveSymIdx: Bool): Opt[TypeIdx] =
if resolveSymIdx then
typeref match
case TypeIdx(SymIdx(nme)) =>
namedTypes.find(_._1.nme == nme).map(t => TypeIdx(NumIdx(t._2)))
case typeidx: TypeIdx => S(typeidx)
case sym: BlockMemberSymbol => namedTypes.get(sym).map(idx => TypeIdx(NumIdx(idx)))
else getType(typeref)

/** Returns the [[TypeIdx]] of the given `typeref`, optionally resolving the symbolic index into a numeric index.
*/
def getType(typeref: TypeIdx | BlockMemberSymbol, resolveSymIdx: Bool = false): Opt[TypeIdx] =
typeref match
case TypeIdx(SymIdx(nme)) if resolveSymIdx =>
namedTypes.find(_._1.nme == nme).map(t => TypeIdx(t._2))
case typeidx: TypeIdx => S(typeidx)
case sym: BlockMemberSymbol if resolveSymIdx => namedTypes.get(sym).map(TypeIdx(_))
case sym: BlockMemberSymbol =>
getType(sym, resolveSymIdx = true).map: numIdx =>
getTypeInfo(numIdx).flatMap(_.id).fold(numIdx)(TypeIdx(_))
def getType(typeref: TypeIdx | BlockMemberSymbol): Opt[TypeIdx] = typeref match
case typeidx: TypeIdx => S(typeidx)
case sym: BlockMemberSymbol => getTypeInfo(typeref).map(ti => TypeIdx(ti.id))

/** Same as [[getType]] but throws an exception when the `typeref` is not found. */
def getType_!(typeref: TypeIdx | BlockMemberSymbol, resolveSymIdx: Bool = false): TypeIdx =
@deprecated("Use the overload without `resolveSymIdx` instead.")
def getType_!(typeref: TypeIdx | BlockMemberSymbol, resolveSymIdx: Bool): TypeIdx =
getType(typeref, resolveSymIdx).getOrElse:
lastWords(s"Missing type definition for ${typeref.prettyString}")

/** Same as [[getType]] but throws an exception when the `typeref` is not found. */
def getType_!(typeref: TypeIdx | BlockMemberSymbol): TypeIdx =
getType(typeref).getOrElse:
lastWords(s"Missing type definition for ${typeref.prettyString}")

/** Returns the [[TypeInfo]] instance associated with the given `typeref`. */
@nowarn("cat=deprecation")
def getTypeInfo(typeref: TypeIdx | BlockMemberSymbol): Opt[TypeInfo] = typeref match
case TypeIdx(NumIdx(idx)) => types.unapply(idx.toInt)
case TypeIdx(SymIdx(nme)) =>
namedTypes.find(_._1.nme == nme).flatMap(t => getTypeInfo(TypeIdx(t._2)))
case sym: BlockMemberSymbol => namedTypes.get(sym).flatMap(idx => getTypeInfo(TypeIdx(idx)))
// TODO(Derppening): Consider adding a `Map[SymIdx, TypeInfo]` for faster lookup
types.find(_.id.id == nme)
case sym: BlockMemberSymbol => namedTypes.get(sym).map(idx => types(idx))

/** Same as [[getTypeInfo]] but throws an exception when the `typeref` is not found. */
def getTypeInfo_!(typeref: TypeIdx | BlockMemberSymbol): TypeInfo =
Expand All @@ -320,23 +368,23 @@ class Ctx(

/** Adds a function into this context. */
def addFunc(sym: Opt[Symbol], funcInfo: FuncInfo): FuncIdx =
val numIdx = NumIdx(functionImports.size + funcs.size)
val numIdx = functionImports.size + funcs.size
funcs += funcInfo
funcInfosByIndex(numIdx) = funcInfo
sym.foreach:
namedFuncs(_) = numIdx
FuncIdx(funcInfo.id.getOrElse(numIdx))
FuncIdx(funcInfo.id)

/** Adds a function import into this context.
*
* Returns the function index in the global function index space.
*/
def addFunctionImport(sym: Opt[Symbol], funcImport: FuncImport): FuncIdx =
val numIdx = NumIdx(functionImports.size + funcs.size)
val numIdx = functionImports.size + funcs.size
functionImports += funcImport
sym.foreach:
namedFuncs(_) = numIdx
FuncIdx(funcImport.id.getOrElse(numIdx))
FuncIdx(funcImport.id)

/** Returns the cached function import for (`module`, `name`), creating it with `createImport` if needed.
*/
Expand All @@ -359,13 +407,18 @@ class Ctx(
memoryImports(idx) = existing.copy(minPages = newMin)
case N =>
val idx = memoryImports.size
memoryImports += MemoryImport(module, name, minPages)
memoryImports += MemoryImport(module, name, SymIdx(name), minPages)
cachedMemoryImport(key) = idx

/** Returns the minimum page requirement of memory import (`module`, `name`) if present. */
@deprecated("Use `getMemoryImport` instead to get the full `MemoryImport` information.")
def getMemoryImportMinPages(module: Str, name: Str): Opt[Int] =
memoryImports.find(m => m.module === module && m.name === name).map(_.minPages)

/** Returns the memory import information for the given (`module`, `name`) tuple if present. */
def getMemoryImport(module: Str, name: Str): Opt[MemoryImport] =
memoryImports.find(m => m.module === module && m.name === name)

/** Adds a data segment into this context. */
def addDataSegment(seg: DataSegment): Unit =
dataSegments += seg
Expand All @@ -375,29 +428,43 @@ class Ctx(
tags += tagInfo
TagIdx(tagInfo.id)

@deprecated("Use the overload without `resolveSymIdx` instead.")
def getFunc(funcref: FuncIdx | Symbol, resolveSymIdx: Bool): Opt[FuncIdx] =
if resolveSymIdx then
funcref match
case FuncIdx(SymIdx(nme)) if resolveSymIdx =>
namedFuncs.find(_._1.nme == nme).map(f => FuncIdx(NumIdx(f._2)))
case funcidx: FuncIdx => S(funcidx)
case sym: Symbol => namedFuncs.get(sym).map(idx => FuncIdx(NumIdx(idx)))
else getFunc(funcref)

/** Returns the [[FuncIdx]] of the given `funcref`, optionally resolving the symbolic index into a numeric index.
*/
def getFunc(funcref: FuncIdx | Symbol, resolveSymIdx: Bool = false): Opt[FuncIdx] = funcref match
case FuncIdx(SymIdx(nme)) if resolveSymIdx =>
namedFuncs.find(_._1.nme == nme).map(f => FuncIdx(f._2))
def getFunc(funcref: FuncIdx | Symbol): Opt[FuncIdx] = funcref match
case funcidx: FuncIdx => S(funcidx)
case sym: Symbol if resolveSymIdx => namedFuncs.get(sym).map(FuncIdx(_))
case sym: Symbol =>
getFunc(sym, resolveSymIdx = true).map: numIdx =>
getFuncInfo(numIdx).flatMap(_.id).fold(numIdx)(FuncIdx(_))
case sym: Symbol => getFuncInfo(funcref).map(fi => FuncIdx(fi.id))

/** Same as [[getFunc]] but throws an exception when the `funcref` is not found. */
def getFunc_!(funcref: FuncIdx | Symbol, resolveSymIdx: Bool = false): FuncIdx =
@deprecated("Use the overload without `resolveSymIdx` instead.")
def getFunc_!(funcref: FuncIdx | Symbol, resolveSymIdx: Bool): FuncIdx =
getFunc(funcref, resolveSymIdx).getOrElse:
lastWords(s"Missing function definition for ${funcref.prettyString}")

/** Same as [[getFunc]] but throws an exception when the `funcref` is not found. */
def getFunc_!(funcref: FuncIdx | Symbol): FuncIdx =
getFunc(funcref).getOrElse:
lastWords(s"Missing function definition for ${funcref.prettyString}")

/** Returns the [[FuncInfo]] instance associated with the given `funcref`. */
@nowarn("cat=deprecation")
def getFuncInfo(funcref: FuncIdx | Symbol): Opt[FuncInfo] = funcref match
case FuncIdx(numIdx @ NumIdx(idx)) =>
funcInfosByIndex.get(numIdx).orElse:
case FuncIdx(NumIdx(idx)) =>
funcInfosByIndex.get(idx).orElse:
val localIdx = idx.toInt - functionImports.size
if localIdx < 0 then N else funcs.unapply(localIdx)
case funcref => getFunc(funcref, resolveSymIdx = true).flatMap(getFuncInfo(_))
case FuncIdx(SymIdx(nme)) =>
// TODO(Derppening): Consider adding a `Map[SymIdx, FuncInfo]` for faster lookup
funcs.find(_.id.id == nme)
case funcref: Symbol => namedFuncs.get(funcref).map(idx => funcs(idx))

/** Same as [[getFuncInfo]] but throws an exception when the `funcref` is not found. */
def getFuncInfo_!(funcref: FuncIdx | Symbol): FuncInfo =
Expand All @@ -412,9 +479,9 @@ class Ctx(

/** Adds a new local variable into the top-most variable scope. */
def addLocal(sym: Local): LocalIdx =
val numIdx = NumIdx(locals.head.size)
val numIdx = locals.head.size
locals.head(sym) = numIdx
LocalIdx(numIdx)
LocalIdx(SymIdx(sym.nme))

/** Adds a [[Seq]] of local variables into the top-most variable scope. */
def addLocals(syms: Seq[Local]): Seq[LocalIdx] =
Expand All @@ -425,7 +492,7 @@ class Ctx(

/** Adds a new variable into the global variable scope. */
def addGlobal(sym: Symbol, globalInfo: GlobalInfo): GlobalIdx =
val numIdx = NumIdx(globals.size)
val numIdx = globals.size
globals += globalInfo
namedGlobals(sym) = numIdx
GlobalIdx(globalInfo.id)
Expand Down Expand Up @@ -472,8 +539,8 @@ class Ctx(
/** Converts a [[Map]] of symbols and their respective numeric identifiers into a [[Seq]] of symbols sorted by its
* numeric index.
*/
private def wasmLocalsToSeq(scope: Map[Symbol, NumIdx]): Seq[Local] =
scope.toSeq.sortBy(_._2.index).map(_._1)
private def wasmLocalsToSeq(scope: Map[Symbol, Int]): Seq[Local] =
scope.toSeq.sortBy(_._2).map(_._1)

/** Returns a tuple containing the variables in the current `global` and `local` scopes respectively.
*/
Expand Down
Loading
Loading