-- CEK Machine Tutorial, part 2: Mockup Demo (wimsio/universities GitHub Wiki)

{-# LANGUAGE LambdaCase #-}

import Data.Map (Map)
import qualified Data.Map as Map
import Data.List (intercalate)
import Data.Maybe (fromMaybe)

-- === Simple Term AST ===
-- | Terms of the small untyped lambda calculus evaluated by 'cek'.
data Term
    = Var String          -- ^ Variable reference, looked up in the machine environment.
    | Lam String Term     -- ^ Lambda abstraction: bound name and body.
    | App Term Term       -- ^ Application of a function term to an argument term.
    | Const Integer       -- ^ Integer literal.
    | Builtin String      -- ^ Named builtin such as @"addInteger"@ or @"multiplyInteger"@.
    | BuiltinAdd Integer  -- ^ @addInteger@ after receiving its first operand.
    | BuiltinMul Integer  -- ^ @multiplyInteger@ after receiving its first operand.
    deriving (Eq, Show)

-- === Cost and Budget ===
-- | Execution cost with a CPU component and a memory component.
data ExBudget = ExBudget
    { cpu :: Int  -- ^ CPU cost.
    , mem :: Int  -- ^ Memory cost.
    }
    deriving (Eq)

-- | Renders a budget as @{ cpu = C, mem = M }@.
instance Show ExBudget where
    show b = concat ["{ cpu = ", show (cpu b), ", mem = ", show (mem b), " }"]

-- | The budget with nothing spent.
zeroBudget :: ExBudget
zeroBudget = ExBudget { cpu = 0, mem = 0 }

-- | Component-wise sum of two budgets.
addBudget :: ExBudget -> ExBudget -> ExBudget
addBudget x y = ExBudget { cpu = cpu x + cpu y, mem = mem x + mem y }

-- | Cost charged for one machine step on a term of the given shape.
-- Builtins (named or partially applied) are the most CPU-expensive;
-- applications are the only steps that charge 2 memory units.
costForStep :: Term -> ExBudget
costForStep (Var _)        = ExBudget 1 1
costForStep (Const _)      = ExBudget 1 1
costForStep (Lam _ _)      = ExBudget 2 1
costForStep (App _ _)      = ExBudget 2 2
costForStep (Builtin _)    = ExBudget 3 1
costForStep (BuiltinAdd _) = ExBudget 3 1
costForStep (BuiltinMul _) = ExBudget 3 1

-- === Costing Mode ===
-- | How the machine treats its accumulated budget.
data CostingMode
    = Counting              -- ^ Only accumulate costs; never fail for budget reasons.
    | Restricting ExBudget  -- ^ Fail with 'OutOfBudget' once either component exceeds this limit.
    deriving (Eq, Show)

-- === Emitter Mode ===
-- | Which log line, if any, each charged step appends.
data EmitterMode
    = NoEmitter             -- ^ Emit nothing.
    | LogEmitter            -- ^ Emit the step's message.
    | LogWithBudgetEmitter  -- ^ Emit the message followed by the running budget.
    deriving (Eq, Show)

-- === CEK Evaluation Errors ===
-- | Reasons evaluation can stop without a value.
data CEKError
    = StructuralError String   -- ^ Malformed program, e.g. an unbound variable.
    | OperationalError String  -- ^ Runtime misuse, e.g. applying a non-function, a
                               --   non-integer operand, or an unknown builtin.
    | OutOfBudget              -- ^ A 'Restricting' limit was exceeded.
    deriving (Eq, Show)

-- === Environment for variable bindings ===
-- | Variable bindings: each name maps to the already-evaluated argument term.
type Env = Map.Map String Term

-- === Evaluation Result ===
-- | Outcome of evaluating a term.
data EvalResult
    = EvalSuccess Term      -- ^ The value the term reduced to.
    | EvalFailure CEKError  -- ^ Why evaluation stopped.
    deriving (Eq, Show)

-- === The CEK Machine ===
-- | State threaded through (and returned from) every 'cek' call.
data MachineState = MachineState
    { env     :: Env          -- ^ Current variable bindings.
    , budget  :: ExBudget     -- ^ Cost accumulated so far.
    , costing :: CostingMode  -- ^ Whether the budget is only counted or also enforced.
    , emitter :: EmitterMode  -- ^ Which log line each step emits.
    , logs    :: [String]     -- ^ Emitted log lines, oldest first.
    }

-- | A fresh machine with no bindings, nothing spent, and no logs.
defaultState :: CostingMode -> EmitterMode -> MachineState
defaultState cm em = MachineState
    { env     = Map.empty
    , budget  = zeroBudget
    , costing = cm
    , emitter = em
    , logs    = []
    }

-- Budget update and logging
-- | Charge one step for @term@ and append the log line (if any) that the
-- emitter mode calls for. The flag reports whether the new total is still
-- within the costing limit. The returned state carries the new budget and
-- logs even when the flag is 'False'.
updateState :: MachineState -> Term -> String -> (MachineState, Bool)
updateState ms term msg =
    (ms { budget = spent, logs = logs ms ++ emitted }, withinLimit (costing ms))
  where
    spent = addBudget (budget ms) (costForStep term)

    withinLimit Counting          = True
    withinLimit (Restricting lim) = cpu spent <= cpu lim && mem spent <= mem lim

    emitted = case emitter ms of
        NoEmitter            -> []
        LogEmitter           -> [msg]
        LogWithBudgetEmitter -> [msg ++ " | Budget: " ++ show spent]

-- === CEK Evaluation Function ===
-- | Evaluate a term, threading the machine state (budget, logs, bindings)
-- through every step.
--
-- Each visited node is charged via 'updateState' before it is processed. If
-- the charge exceeds a 'Restricting' limit, the result is
-- @EvalFailure OutOfBudget@. Failures from sub-evaluations are propagated
-- together with the state reached at that point. The returned budget and
-- logs therefore include every step that was actually taken, including
-- steps spent evaluating an argument that turned out to have the wrong kind.
--
-- NOTE(review): lambdas evaluate to themselves without capturing an
-- environment. Beta-reduction inserts the argument into the shared 'env',
-- and that binding stays in the state after the call returns. Variables are
-- therefore resolved against whatever bindings exist when they are reached
-- (dynamic scoping) rather than lexically; test3 relies on this. A real CEK
-- machine would return closures instead.
cek :: MachineState -> Term -> (EvalResult, MachineState)
cek ms term = case term of
    Const n      -> evaluatesToItself ("Const " ++ show n)
    Lam x _      -> evaluatesToItself ("Lam " ++ x)
    Builtin b    -> evaluatesToItself ("Builtin " ++ b)
    BuiltinAdd n -> evaluatesToItself ("Partial Add " ++ show n)
    BuiltinMul n -> evaluatesToItself ("Partial Mul " ++ show n)
    Var x ->
        charge ms term ("Var " ++ x) $ \ms' ->
            case Map.lookup x (env ms') of
                Just val -> (EvalSuccess val, ms')
                Nothing  -> (EvalFailure (StructuralError ("Unbound variable: " ++ x)), ms')
    App f a ->
        charge ms term "App" $ \ms1 ->
            case cek ms1 f of
                (EvalSuccess fval, ms2) -> apply ms2 fval a
                failure                 -> failure
  where
    -- Values (literals, lambdas, builtins) only pay their own step cost.
    evaluatesToItself :: String -> (EvalResult, MachineState)
    evaluatesToItself msg = charge ms term msg $ \ms' -> (EvalSuccess term, ms')

    -- Charge one step for @t@, then continue with @k@; stop with
    -- OutOfBudget (keeping the charged state) if the limit is exceeded.
    charge :: MachineState -> Term -> String
           -> (MachineState -> (EvalResult, MachineState))
           -> (EvalResult, MachineState)
    charge st t msg k =
        let (st', ok) = updateState st t msg
        in if ok then k st' else (EvalFailure OutOfBudget, st')

    -- Apply an already-evaluated function value to a not-yet-evaluated argument.
    apply :: MachineState -> Term -> Term -> (EvalResult, MachineState)
    apply st fval arg = case fval of
        Lam x body ->
            -- Previously a partial let-pattern: a failing argument crashed
            -- the program instead of returning EvalFailure.
            withValue st arg $ \aval st' ->
                cek (st' { env = Map.insert x aval (env st') }) body
        Builtin bname -> cekBuiltin st bname arg
        BuiltinAdd n  -> saturate st fval "Add" (+) n arg
        BuiltinMul n  -> saturate st fval "Mul" (*) n arg
        other ->
            (EvalFailure (OperationalError ("Trying to apply non-function: " ++ show other)), st)

    -- Finish a partially applied arithmetic builtin: evaluate the second
    -- operand, charge one application step, and combine the integers.
    saturate :: MachineState -> Term -> String -> (Integer -> Integer -> Integer)
             -> Integer -> Term -> (EvalResult, MachineState)
    saturate st partial label op n arg =
        withInteger st arg (label ++ " expects integer argument") $ \m st' ->
            charge st' (App partial arg) (label ++ " " ++ show n ++ " " ++ show m) $ \st'' ->
                (EvalSuccess (Const (op n m)), st'')

    -- First application of a named builtin: evaluate the first operand and
    -- return the matching partially applied builtin.
    cekBuiltin :: MachineState -> String -> Term -> (EvalResult, MachineState)
    cekBuiltin st "addInteger" arg =
        withInteger st arg "addInteger expects integer argument" $ \n st' ->
            (EvalSuccess (BuiltinAdd n), st')
    cekBuiltin st "multiplyInteger" arg =
        withInteger st arg "multiplyInteger expects integer argument" $ \n st' ->
            (EvalSuccess (BuiltinMul n), st')
    cekBuiltin st other _ =
        (EvalFailure (OperationalError ("Unknown builtin: " ++ other)), st)

    -- Evaluate @arg@ and pass its value to @k@; evaluation failures propagate.
    withValue :: MachineState -> Term
              -> (Term -> MachineState -> (EvalResult, MachineState))
              -> (EvalResult, MachineState)
    withValue st arg k = case cek st arg of
        (EvalSuccess v, st') -> k v st'
        failure              -> failure

    -- Like 'withValue', but the value must be an integer constant. Any other
    -- value yields an OperationalError carrying @errMsg@, paired with the
    -- state *after* evaluating @arg@ so that its cost is not lost.
    withInteger :: MachineState -> Term -> String
                -> (Integer -> MachineState -> (EvalResult, MachineState))
                -> (EvalResult, MachineState)
    withInteger st arg errMsg k = withValue st arg $ \v st' -> case v of
        Const m -> k m st'
        _       -> (EvalFailure (OperationalError errMsg), st')

-- === Utility: Run and pretty print ===
-- | Evaluate @term@ from a fresh machine state, then print the result, the
-- final budget, and each collected log line followed by a blank line.
runCEK :: CostingMode -> EmitterMode -> Term -> IO ()
runCEK cm em term = do
    putStrLn ("Result: " ++ show result)
    putStrLn ("Final Budget: " ++ show (budget finalState))
    putStrLn "Logs:"
    mapM_ putStrLn (logs finalState)
    putStrLn ""
  where
    (result, finalState) = cek (defaultState cm em) term

-- === Test Examples ===

-- (\x -> x + 1) 41  -- with an unrestricted budget: Const 42
test1 :: Term
test1 = App increment (Const 41)
  where
    increment = Lam "x" (App (App (Builtin "addInteger") (Var "x")) (Const 1))

-- (\x -> x * 2) 5  -- with an unrestricted budget: Const 10
test2 :: Term
test2 = App double (Const 5)
  where
    double = Lam "x" (App (App (Builtin "multiplyInteger") (Var "x")) (Const 2))

-- (\x -> (\y -> x + y)) 10 20  -- with an unrestricted budget: Const 30
test3 :: Term
test3 = App (App plus (Const 10)) (Const 20)
  where
    plus = Lam "x" (Lam "y" (App (App (Builtin "addInteger") (Var "x")) (Var "y")))

-- | Run the four demo programs, each with its own costing and emitter mode.
-- Test 3 uses a 5/5 budget. The App, App and Lam steps charge 6 CPU before
-- any beta-reduction happens, so it reports OutOfBudget.
main :: IO ()
main = mapM_ runCase cases
  where
    runCase (header, cm, em, t) = putStrLn header >> runCEK cm em t
    cases =
        [ ("-- Test 1: (\\x -> x + 1) 41 --", Counting, LogWithBudgetEmitter, test1)
        , ("\n-- Test 2: (\\x -> x * 2) 5 --", Counting, LogEmitter, test2)
        , ("\n-- Test 3: (\\x -> (\\y -> x + y)) 10 20 --", Restricting (ExBudget 5 5), LogWithBudgetEmitter, test3)
        , ("\n-- Test 4: Structural Error (unbound var) --", Counting, LogEmitter, Var "z")
        ]