diff --git a/src/Agda/Compiler/Rust/AgdaToRustExpr.hs b/src/Agda/Compiler/Rust/AgdaToRustExpr.hs index 3e63c95..4fa9190 100644 --- a/src/Agda/Compiler/Rust/AgdaToRustExpr.hs +++ b/src/Agda/Compiler/Rust/AgdaToRustExpr.hs @@ -5,17 +5,19 @@ module Agda.Compiler.Rust.AgdaToRustExpr ( compile, compileModule ) where import Control.Monad.IO.Class ( MonadIO(liftIO) ) import qualified Data.List.NonEmpty as Nel -import Agda.Compiler.Backend ( IsMain ) +import Agda.Compiler.Backend ( Defn(..), funCompiled, funClauses, IsMain, RecordData(..)) import Agda.Syntax.Abstract.Name ( QName ) import Agda.Syntax.Common.Pretty ( prettyShow ) -import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), moduleNameParts ) +import Agda.Syntax.Common ( moduleNameParts ) +import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), NamedName, WithOrigin(..), Ranged(..) ) import Agda.Syntax.Internal ( - Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..), + Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), Dom'(..), unDom, PatternInfo(..), Pattern'(..), qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) ) import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName ) import Agda.TypeChecking.Monad.Base ( Definition(..) ) import Agda.TypeChecking.Monad import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) ) +import Agda.TypeChecking.Telescope ( teleNamedArgs, teleArgs, teleArgNames ) import Agda.Compiler.Rust.CommonTypes ( Options, CompiledDef, ModuleEnv ) import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustName, RustType, RustElem(..), FunBody ) @@ -30,24 +32,42 @@ compile _ _ _ Defn{..} compileDefn :: QName -> Defn -> CompiledDef compileDefn defName theDef = + -- https://hackage.haskell.org/package/Agda/docs/Agda-Compiler-Backend.html#t:Defn case theDef of Datatype{dataCons = fields} -> compileDataType defName fields Function{funCompiled = funDef, funClauses = fc} -> compileFunction defName funDef fc - _ -> - Unhandled "compileDefn" (show defName ++ " = " ++ show theDef) + RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) -> + compileRecord defName recFields recTel + other -> + Unhandled "compileDefn" (show defName ++ "\n = \n" ++ show theDef) compileDataType :: QName -> [QName] -> CompiledDef -compileDataType defName fields = TeEnum (showName defName) (map showName fields) +compileDataType defName fields = ReEnum (showName defName) (map showName fields) + +compileRecord :: QName -> [Dom QName] -> Telescope -> CompiledDef +compileRecord defName recFields recTel = ReRec (showName defName) (foldl varsFromTelescope [] recTel) + +varsFromTelescope :: [RustElem] -> Dom Type -> [RustElem] +varsFromTelescope xs dt = RustElem (nameFromDom dt) (fromDom dt) : xs + +nameFromDom :: Dom Type -> RustName +nameFromDom dt = case (domName dt) of + Nothing -> error ("\nnameFromDom [" ++ show dt ++ "]\n") + Just a -> namedNameToStr a + +-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-Syntax-Common.html#t:NamedName +namedNameToStr :: NamedName -> RustName +namedNameToStr n = rangedThing (woThing n) compileFunction :: QName -> Maybe CompiledClauses -> [Clause] -> CompiledDef -compileFunction defName funDef fc = TeFun +compileFunction defName funDef fc = ReFun (showName defName) - (RustElem (compileFunctionArgument fc) (compileFunctionArgType fc)) + [(RustElem (compileFunctionArgument fc) (compileFunctionArgType fc))] (compileFunctionResultType fc) (compileFunctionBody funDef) @@ -120,7 +140,7 @@ showName = prettyShow . qnameName compileModule :: TopLevelModuleName -> [CompiledDef] -> CompiledDef compileModule mName cdefs = - TeMod (moduleName mName) cdefs + ReMod (moduleName mName) cdefs moduleName :: TopLevelModuleName -> String moduleName n = prettyShow (Nel.last (moduleNameParts n)) diff --git a/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs b/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs index f14e4d3..5edda62 100644 --- a/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs +++ b/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs @@ -1,12 +1,12 @@ module Agda.Compiler.Rust.PrettyPrintingUtils ( prettyPrintRustExpr, moduleHeader ) where -import Data.List ( intersperse ) +import Data.List ( intersperse, intercalate ) import Agda.Compiler.Rust.CommonTypes ( CompiledDef ) import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustElem(..), FunBody ) prettyPrintRustExpr :: CompiledDef -> String prettyPrintRustExpr def = case def of - (TeEnum name fields) -> + (ReEnum name fields) -> "enum" <> exprSeparator <> name <> exprSeparator @@ -14,7 +14,7 @@ prettyPrintRustExpr def = case def of indent -- TODO this is too simplistic indentation <> concat (intersperse ", " fields)) <> defsSeparator - (TeFun fName (RustElem aName aType) resType fBody) -> + (ReFun fName [RustElem aName aType] resType fBody) -> "pub fn" <> exprSeparator <> fName <> argList ( @@ -25,12 +25,15 @@ prettyPrintRustExpr def = case def of <> exprSeparator <> bracket ( indent <> (prettyPrintFunctionBody fBody)) <> defsSeparator - (TeMod mName defs) -> + (ReMod mName defs) -> moduleHeader mName <> bracket ( defsSeparator -- empty line before first definition in module <> combineLines (map prettyPrintRustExpr defs)) <> defsSeparator + (ReRec name args) -> "pub struct" <> exprSeparator <> name + <> exprSeparator <> (bracket (combineThem ",\n" (map (indent ++) (map printVar args)))) + <> defsSeparator (Unhandled name payload) -> "" -- XXX at the end there should be no Unhandled expression -- other -> "unsupported prettyPrintRustExpr " ++ (show other) @@ -41,6 +44,9 @@ bracket str = "{\n" <> str <> "\n}" argList :: String -> String argList str = "(" <> str <> ")" +printVar :: RustElem -> String +printVar (RustElem sName sType) = sName <> ":" <> exprSeparator <> sType + indent :: String indent = " " @@ -59,6 +65,9 @@ funReturnTypeSeparator = "->" combineLines :: [String] -> String combineLines xs = unlines (filter (not . null) xs) +combineThem :: String -> [String] -> String +combineThem s xs = intercalate s xs + prettyPrintFunctionBody :: FunBody -> String prettyPrintFunctionBody fBody = "return" <> exprSeparator <> fBody <> ";" diff --git a/src/Agda/Compiler/Rust/RustExpr.hs b/src/Agda/Compiler/Rust/RustExpr.hs index 3bdae4d..c08c965 100644 --- a/src/Agda/Compiler/Rust/RustExpr.hs +++ b/src/Agda/Compiler/Rust/RustExpr.hs @@ -15,9 +15,10 @@ data RustElem = RustElem RustName RustType deriving ( Show ) data RustExpr - = TeMod RustName [RustExpr] - | TeEnum RustName [RustName] - | TeFun RustName RustElem RustType FunBody + = ReMod RustName [RustExpr] + | ReEnum RustName [RustName] + | ReFun RustName [RustElem] RustType FunBody + | ReRec RustName [RustElem] | Unhandled RustName String deriving ( Show ) diff --git a/test/hello.agda b/test/hello.agda index c904561..c9ace23 100644 --- a/test/hello.agda +++ b/test/hello.agda @@ -18,17 +18,16 @@ id_rgb x = x -- product types --- record ThePair : Set where --- field --- pairFst : Rgb --- pairSnd : WeekDay --- {-# COMPILE AGDA2RUST ThePair #-} +record ThePair : Set where + field + pairFst : Rgb + pairSnd : WeekDay +{-# COMPILE AGDA2RUST ThePair #-} --- record Foo (A : Set) : Set where --- field --- foo : Pair A A - --- TODO Data.Product as Rust tuple +record Foo : Set where + field + foo : ThePair +{-# COMPILE AGDA2RUST Foo #-} -- TODO function returning constant result -- as-friday : TheRgb → TheWeekDay @@ -53,6 +52,7 @@ id_rgb x = x -- TODO polymorphic types +-- TODO Data.Product as Rust tuple -- TODO Data.Bool -- TODO if expressions, and, or diff --git a/test/hello.rs b/test/hello.rs index 6f5a207..fcddcbf 100644 --- a/test/hello.rs +++ b/test/hello.rs @@ -12,5 +12,14 @@ pub fn id_rgb(x: Rgb) -> Rgb { return x; } +pub struct ThePair { + pairSnd: WeekDay, + pairFst: Rgb +} + +pub struct Foo { + foo: ThePair +} + }