"""Macro expansion for a small Scheme front end.

Does macro expansion using Kohlbecker and Wand's 'Macro by Example'
algorithm.  Each macro rule is a (pattern, template, literals) triple built
with makeMacro().  expand() rewrites S-expressions using the `macros` table.
convert() turns the expanded S-expression into `ast` nodes, and simplify()
makes variable references share the objects of their binding lambda's
parameters.

NOTE: the module defines a top-level function named `all`, which shadows the
builtin `all` everywhere in this module -- do not call the builtin here.
"""

# NOTE: `ast` is the project's own S-expression/AST module (it provides
# SExp, map, cons, If, Lambda, ...), not the standard-library `ast`.
import ast
import parse

# Debug verbosity: debug(v, level) prints only when verbose >= level, so the
# default of -1 silences every call.
verbose = -1


def debug(v, level=0):
    """Print `v` when the module-level `verbose` is at least `level`."""
    if verbose >= level:
        print(v)


# the B function
def makeMatcher(pattern, literals):
    """Compile `pattern` into a predicate ``match(form)``.

    Named after the B function of Kohlbecker and Wand.

    pattern  -- parsed pattern S-expression.
    literals -- list of identifier names that must appear verbatim in the
                input; every other symbol in the pattern is a pattern
                variable.

    Returns a function of one parsed form whose result is truthy iff the
    form matches.  Raises Exception if `pattern` is not (), a literal, a
    variable, a repetition (as recognised by ``pattern.match("(x ...)")``)
    or a list.
    """
    if pattern.match("()"):
        def matchEmpty(form):
            return form.match("()")
        return matchEmpty

    if not pattern.isList() and str(pattern) in literals:
        literal = str(pattern)

        def matchLiteral(form):
            # A literal matches only the identical identifier.  (Testing
            # `str(form) in literals` let e.g. `else` match `=>` in cond.)
            return str(form) == literal
        return matchLiteral

    if pattern.isVariable():
        def matchVariable(form):
            # A pattern variable matches anything except an identifier that
            # is one of this macro's literals.
            return not form.isVariable() or str(form) not in literals
        return matchVariable

    if pattern.match("(x ...)"):
        element_matcher = makeMatcher(pattern.head(), literals)

        def matchRepetition(form):
            # Any list (including the empty one) whose every element
            # matches the repeated sub-pattern.
            if not form.isList():
                return False
            for element in form:
                if not element_matcher(element):
                    return False
            return True
        return matchRepetition

    if pattern.isList():
        head_matcher = makeMatcher(pattern.head(), literals)
        tail_matcher = makeMatcher(pattern.tail(), literals)

        def matchList(form):
            return (form.isList() and not form.null()
                    and head_matcher(form.head())
                    and tail_matcher(form.tail()))
        return matchList

    raise Exception("Pattern did not match! '%s'" % pattern)


def testMatcher():
    """Self-test for makeMatcher; raises Exception on the first failure.

    Not run automatically -- call it by hand.
    """
    def test(s_input, s_pattern, literals=()):
        form = parse.parseString(s_input)
        pattern = parse.parseString(s_pattern)
        if not makeMatcher(pattern, list(literals))(form):
            raise Exception("Pattern '%s' failed to match '%s'" % (s_pattern, s_input))

    def testFail(s_input, s_pattern, literals=()):
        try:
            test(s_input, s_pattern, literals)
        except Exception:
            return
        raise Exception("Pattern '%s' did not fail to match '%s'" % (s_pattern, s_input))

    test("(x)", "(x)")
    test("(a b c)", "(x ...)")
    test("((a b) (c d))", "((x y) ...)")
    test("(a (b c) (d e))", "(a (x y) ...)")
    test("(let ([a b] [c d]) 5)", "(let ([x y] ...) body)")
    testFail("(x y z)", "((x) ...)")
    testFail("((a b) (a b c) (a b c d))", "((x y) ...)")
    # A literal matches only itself, not any other literal of the macro.
    test("(cond (else 1))", "(cond (else x))", ["else", "=>"])
    testFail("(cond (=> 1))", "(cond (else x))", ["else", "=>"])


def _patternDepths(pattern, literals, depth=0):
    """Map every pattern variable name in `pattern` to its ellipsis depth.

    Uses the same case analysis as makeEnvironment but looks only at the
    pattern, so it can describe the bindings of a repetition that matched
    zero elements.
    """
    if pattern.match("()") or str(pattern) in literals:
        return {}
    if pattern.isVariable():
        return {str(pattern): depth}
    if pattern.match("(x ...)"):
        return _patternDepths(pattern.head(), literals, depth + 1)
    if pattern.isList():
        depths = _patternDepths(pattern.head(), literals, depth)
        depths.update(_patternDepths(pattern.tail(), literals, depth))
        return depths
    raise Exception("Internal error")


# the D function
# map pattern variables to levels/values
def makeEnvironment(pattern, literals):
    """Compile `pattern` into ``bind(form) -> env``.

    `env` maps each pattern-variable name to ``(level, value)``:
      * level 0 -- value is ``str()`` of the matched form;
      * level n > 0 -- value is a list with one level-(n-1) value per
        repetition of the enclosing ellipsis (``[]`` for zero repetitions).
    Literals and () contribute no bindings.  `form` is expected to have
    already passed makeMatcher for the same pattern; behaviour on
    non-matching input is not defined here.

    Raises Exception("Internal error") for a pattern shape it cannot handle.
    """
    if pattern.match("()"):
        def bindEmpty(form):
            return {}
        return bindEmpty

    if str(pattern) in literals:
        def bindLiteral(form):
            return {}
        return bindLiteral

    if pattern.isVariable():
        name = str(pattern)

        def bindVariable(form):
            # Warning! Converting the input to a string may not be right, but
            # strings are the easiest to test
            return {name: (0, str(form))}
        return bindVariable

    if pattern.match("(x ...)"):
        element_env = makeEnvironment(pattern.head(), literals)
        # Needed only for zero repetitions, where no element env exists to
        # tell us which variables (and at which depth) the repetition binds.
        element_depths = _patternDepths(pattern.head(), literals)

        def bindRepetition(form):
            envs = [element_env(element) for element in form]
            if not envs:
                empty = {}
                for name, depth in element_depths.items():
                    empty[name] = (depth + 1, [])
                return empty
            combined = {}
            for name, (level, _) in envs[0].items():
                combined[name] = (level + 1, [env[name][1] for env in envs])
            return combined
        return bindRepetition

    if pattern.isList():
        head_env = makeEnvironment(pattern.head(), literals)
        tail_env = makeEnvironment(pattern.tail(), literals)

        def bindList(form):
            # Tail bindings win on a name clash, as in the original merge.
            env = dict(head_env(form.head()))
            env.update(tail_env(form.tail()))
            return env
        return bindList

    raise Exception("Internal error")


def testEnvironment():
    """Self-test for makeEnvironment; raises Exception on the first failure.

    Not run automatically -- call it by hand.
    """
    def test(s_pattern, s_input, environment):
        pattern = parse.parseString(s_pattern)
        form = parse.parseString(s_input)
        out = makeEnvironment(pattern, [])(form)
        if environment != out:
            raise Exception("Pattern '%s' applied to '%s' does not produce '%s', instead produced '%s'" % (s_pattern, s_input, environment, out))

    test("x", "x", {"x": (0, "x")})
    test("(x ...)", "(1 2 3)", {"x": (1, ['1', '2', '3'])})
    test("((x ...) ...)", "((a b c) (d e f))", {"x": (2, [['a', 'b', 'c'], ['d', 'e', 'f']])})
    test("(let ([v e] ...) body)", "(let ([a 1] [b 2] [c 3]) (+ a b c))",
         {"let": (0, "let"), "v": (1, ['a', 'b', 'c']), "e": (1, ['1', '2', '3']), 'body': (0, '(+ a b c)')})
    # Zero repetitions bind every variable to an empty list at its level.
    test("(x ...)", "()", {"x": (1, [])})
    test("((x ...) ...)", "()", {"x": (2, [])})


def _freeVariables(sexp):
    """Return every atom of `sexp` in left-to-right order.

    Despite the name this also includes literals, numbers and symbols that
    are not pattern variables; callers filter by looking atoms up in an
    environment.
    """
    if not sexp.isList():
        return [sexp]
    if sexp.null():
        return []
    return _freeVariables(sexp.head()) + _freeVariables(sexp.tail())


def _controllable(variables, env):
    """True if any of `variables` is bound in `env` at level > 0."""
    for var in variables:
        binding = env.get(str(var))
        if binding is not None and binding[0] > 0:
            return True
    return False


def _decompose(env, variables):
    """Split `env` into one environment per repetition of an ellipsis.

    Every level>0 variable among `variables` must have the same number of
    repetitions.  Level-0 variables are copied into each environment.
    Names not bound in `env` are bound to themselves at level 0, so the
    template emits them unchanged.

    Raises Exception if the repeated variables have different lengths.
    """
    length = None
    for var in variables:
        binding = env.get(str(var))
        if binding is None:
            continue
        level, values = binding
        if level > 0:
            if length is None:
                length = len(values)
            elif len(values) != length:
                raise Exception("Length of pattern variable lists are not the same: %s" % env)
    if length is None:
        raise Exception("Internal error")

    environments = []
    for index in range(length):
        out = {}
        for var in variables:
            name = str(var)
            binding = env.get(name)
            if binding is None:
                out[name] = (0, name)
                continue
            level, values = binding
            if level == 0:
                out[name] = (0, values)
            else:
                out[name] = (level - 1, values[index])
        environments.append(out)
    return environments


def makeOutput(pattern, literals):
    """Compile a template into ``render(env) -> sexp``.

    Symbols bound in `env` (see makeEnvironment) are replaced by
    ``parse.parseString`` of their level-0 value.  Unbound symbols and
    literals are emitted unchanged.  A repetition ``(t ...)`` renders `t`
    once per repetition of the level>0 variables it mentions.

    Raises Exception when a level>0 variable is used outside an ellipsis,
    when an ellipsis mentions no level>0 variable, when repeated variables
    have different lengths, or (at build time) for an unhandled template
    shape.
    """
    if pattern.match("()"):
        def renderEmpty(env):
            return parse.parseString("()")
        return renderEmpty

    if str(pattern) in literals:
        def renderLiteral(env):
            return pattern
        return renderLiteral

    if pattern.isVariable():
        name = str(pattern)

        def renderVariable(env):
            try:
                level, values = env[name]
            except KeyError:
                # Not a pattern variable: copy the template symbol.
                return pattern
            if level > 0:
                raise Exception("Pattern variable `%s' bound to a list of values: '%s' but expected the variable to be bound to a single value" % (pattern, values))
            return parse.parseString(values)
        return renderVariable

    if pattern.match("(x ...)"):
        element_output = makeOutput(pattern.head(), literals)
        free = _freeVariables(pattern.head())

        def renderRepetition(env):
            if not _controllable(free, env):
                raise Exception("No pattern variables with ellipses: %s in %s" % (pattern, env))
            return ast.map(element_output, _decompose(env, free))
        return renderRepetition

    if pattern.isList():
        head_output = makeOutput(pattern.head(), literals)
        tail_output = makeOutput(pattern.tail(), literals)

        def renderList(env):
            return ast.cons(head_output(env), tail_output(env))
        return renderList

    raise Exception("Internal error")


def testOutput():
    """Self-test for makeOutput; raises Exception on the first failure.

    Not run automatically -- call it by hand.
    """
    def test(s_pattern, s_input, s_template, output):
        pattern = parse.parseString(s_pattern)
        form = parse.parseString(s_input)
        template = parse.parseString(s_template)
        env = makeEnvironment(pattern, [])(form)
        out = makeOutput(template, [])(env)
        if str(out) != str(output):
            raise Exception("Expected '%s' but got '%s'" % (output, out))

    test("y", "x", "y", "x")
    test("(a b)", "(1 2)", "(b a)", "(2 1)")
    test("(x ...)", "(1 2 3)", "(+ x ...)", "(+ 1 2 3)")
    test("(let ([v e] ...) body)", "(let ([a 1] [b 2]) (+ a b))", "((lambda (v ...) body) e ...)", "((lambda (a b) (+ a b)) 1 2)")


def makeMacro(s_pattern, s_template, literals=None):
    """Return a thunk producing the ``(matcher, expander)`` pair of a rule.

    s_pattern, s_template -- source text of the rule's pattern and template.
    literals              -- identifiers matched verbatim (default: none).

    Parsing and compilation are deferred until the thunk is first called;
    the result is cached and returned on every later call.  The expander
    turns a KeyError raised while binding or rendering into
    Exception("Internal error: ...") after printing its traceback.
    """
    if literals is None:
        literals = []

    def build():
        pattern = parse.parseString(s_pattern)
        template = parse.parseString(s_template)
        env = makeEnvironment(pattern, literals)
        out = makeOutput(template, literals)
        matcher = makeMatcher(pattern, literals)

        def expander(form):
            try:
                return out(env(form))
            except KeyError as e:
                import traceback
                traceback.print_exc()
                raise Exception("Internal error: %s" % e)
        return (matcher, expander)

    cache = []

    def memo():
        if not cache:
            cache.append(build())
        return cache[0]
    return memo


# Macro table: keyword -> rules, tried in order by expand().
# NOTE(review): templates introduce fixed names such as `or-part` and `v`,
# which can capture user variables (the expander is not hygienic).
macros = {
    "and": [makeMacro("(and e1 e2 erest ...)", "(if e1 (and e2 erest ...) #f)"),
            makeMacro("(and e)", "e"),
            makeMacro("(and)", "#t")],
    "if": [makeMacro("(if condition then)", "(if condition then #f)")],
    "begin": [makeMacro("(begin (define (v1 vs ...) body ...) begin-rest ...)",
                        "(begin+define (@defines (define v1 (lambda (vs ...) body ...))) (begin begin-rest ...))",
                        literals=['define']),
              makeMacro("(begin (define v e) begin-rest ...)",
                        "(begin+define (@defines (define v e)) (begin begin-rest ...))",
                        literals=['define']),
              makeMacro("(begin expr ...)", "(begin+expr expr ...)")],
    "begin+define": [makeMacro("(begin+define (@defines defs ...) (begin+define (@defines inner-def) rest ...))",
                               "(begin+define (@defines inner-def defs ...) rest ...)",
                               literals=['@defines', 'begin+expr']),
                     makeMacro("(begin+define (@defines (define v e) ...) (begin+expr expr ...))",
                               "(letrec ([v e] ...) expr ...)",
                               literals=['@defines', 'begin+expr'])],
    "begin+expr": [makeMacro("(begin+expr (begin+expr exp ...))", "(begin+expr exp ...)",
                             literals=['begin+expr']),
                   # NOTE: disabled alternatives kept for reference:
                   # makeMacro("(begin+expr (set! v e) (set! v2 e2) rest ...)",
                   #     "(begin+expr (set! v e) (begin+expr (set! v2 e2) rest ...))", literals = ['set!']),
                   # makeMacro("(begin+expr exp1 exp2 exps ...)", "(begin+expr (let ([dummy$ exp1]) (begin+expr exp2 exps ...)))")
                   ],
    "or": [makeMacro("(or e1 e2 rest ...)", "(let ([or-part e1]) (if or-part or-part (or e2 rest ...)))"),
           makeMacro("(or e1)", "e1"),
           makeMacro("(or)", "#f")],
    "lambda": [makeMacro("(lambda args (begin body ...))", "(@lambda args (begin body ...))",
                         literals=['begin']),
               makeMacro("(lambda args body)", "(@lambda args body)"),
               makeMacro("(lambda args body ...)", "(lambda args (begin body ...))",
                         literals=['begin'])],
    "let": [makeMacro("(let ([v e] ...) body ...)", "((lambda (v ...) body ...) e ...)")],
    "let*": [makeMacro("(let* (set1 sets ...) body ...)", "(let (set1) (let* (sets ...) body ...))"),
             makeMacro("(let* () body ...)", "(let () body ...)")],
    "letrec": [makeMacro("(letrec ([v e] ...) body ...)",
                         "(let ([v #f] ...) (begin (set! v e) ...) body ...)")],
    "cond": [makeMacro("(cond)", "#f", literals=['else', '=>']),
             makeMacro("(cond (question answer) clauses ...)",
                       "(if question answer (cond clauses ...))", literals=['else', '=>']),
             makeMacro("(cond (exp) clauses ...)",
                       "(let ([v exp]) (if v v (cond clauses ...)))", literals=['else', '=>']),
             makeMacro("(cond (exp => func) clauses ...)",
                       "(let ([v exp]) (if v (func v) (cond clauses ...)))", literals=['else', '=>']),
             makeMacro("(cond (else blah))", "blah", literals=['else', '=>'])],
}


def expand(expr):
    """Expand macros in `expr` and return the rewritten S-expression.

    For a non-empty list whose head names an entry in `macros`, the first
    rule whose pattern matches is applied and its result expanded again.
    Otherwise expand() recurses into the sub-forms.  It does not retry the
    macros on the form after its sub-forms have been expanded.
    """
    if not expr.isList() or expr.null():
        return expr
    # .get() rather than try/except KeyError: the old handler also caught
    # KeyErrors raised inside macro application and recursive expansion,
    # silently treating real errors as "no macro for this head".
    for macro in macros.get(str(expr.head()), ()):
        test, func = macro()
        if test(expr):
            return expand(func(expr))
        debug("Failed to match", 1)
    return ast.map(expand, expr.body)


def testExpand():
    """Self-test for expand; raises Exception on the first failure.

    Not run automatically -- call it by hand.
    """
    def test(source, output):
        out = expand(parse.parseString(source))
        if str(out) != str(output):
            raise Exception("Expected\n%s\nGot\n%s" % (output, out))

    test("(let ([v 1]) 5)", "((lambda (v) 5) 1)")
    test("(and 1 2 3)", "(if 1 (if 2 3 #f) #f)")
    test("(and)", "#t")
    test("(or)", "#f")
    test("(or 1 2 3)", expand(parse.parseString("(let ([or-part 1]) (if or-part or-part (let ([or-part 2]) (if or-part or-part 3))))")))
    test("(if 1 2)", "(if 1 2 #f)")
    test("(cond (5))", expand(parse.parseString("(let ([v 5]) (if v v #f))")))
    test("(cond (else 8))", "8")


# Symbols converted to ast.Primitive("<name>-cps") by convert().
_PRIMITIVE_OPERATORS = frozenset(['display', 'newline', '+', '*', '-', '/', '=', '<', '>', 'call/cc'])


def convert(sexp):
    """Convert an expanded S-expression into `ast` nodes.

    Recognises the core forms left after expansion: ``(if c t e)``,
    ``(set! v e)``, ``(@lambda args body)`` and ``(begin+expr e ...)``.
    Any other non-empty list becomes an application, ``()`` becomes
    ast.Null(), and symbols in _PRIMITIVE_OPERATORS become primitives.
    Other atoms are returned unchanged.
    """
    if sexp.isList():
        if sexp.null():
            return ast.Null()
        head = sexp.head()
        keyword = None if head.isList() else str(head)
        body = sexp.body
        if keyword == 'if':
            return ast.If(convert(body[1]), convert(body[2]), convert(body[3]))
        if keyword == 'set!':
            return ast.Set(convert(body[1]), convert(body[2]))
        if keyword == '@lambda':
            # body[1] is the parameter list; its elements become the args.
            return ast.Lambda(body[1].body, convert(body[2]))
        if keyword == 'begin+expr':
            if len(body) == 2:
                return convert(body[1])
            return ast.Begin([convert(s) for s in body[1:]])
        # otherwise it's an application
        return ast.Application(convert(body[0]), [convert(s) for s in body[1:]])
    if sexp.isVariable() and str(sexp) in _PRIMITIVE_OPERATORS:
        return ast.Primitive(sexp.name + "-cps")
    return sexp


def replaceVars(body, args):
    """Replace variable references in `body` by the matching objects of `args`.

    A variable whose name equals that of an element of `args` is replaced
    by the first such element, so references share the binding object.
    Inner lambdas shadow outer names with their own parameters.

    Raises Exception for a node kind it does not handle.
    """
    if body.isVariable():
        name = str(body)
        for arg in args:
            if str(arg) == name:
                return arg
        return body
    if body.isLambda():
        shadowed = set(str(x) for x in body.args)
        # Inner parameters come first so they win the lookup above.
        scope = body.args[:]
        scope.extend(var for var in args if str(var) not in shadowed)
        return ast.Lambda(body.args, replaceVars(body.body, scope))
    if body.isApplication():
        return ast.Application(replaceVars(body.proc, args), [replaceVars(arg, args) for arg in body.args])
    if body.isBegin():
        return ast.Begin([replaceVars(exp, args) for exp in body.body])
    if body.isIf():
        return ast.If(replaceVars(body.cond, args), replaceVars(body.then, args), replaceVars(body.elser, args))
    if body.isSet():
        return ast.Set(replaceVars(body.var, args), replaceVars(body.exp, args))
    if body.isPrimitive() or body.isDatum():
        return body
    raise Exception("Replace vars in %s" % body)


def simplify(sexp):
    """Simplify an `ast` tree produced by convert().

    Splices a leading nested Begin into its parent Begin and rewrites each
    lambda body with replaceVars.  Every other node is rebuilt with
    simplified children.

    Raises Exception for a node kind it does not handle.
    """
    if sexp.isList():
        return ast.map(simplify, sexp.body)
    if sexp.isLambda():
        return ast.Lambda(sexp.args, simplify(replaceVars(sexp.body, sexp.args)))
    if sexp.isBegin():
        exps = sexp.body
        if len(exps) >= 2 and exps[0].isBegin():
            return simplify(ast.Begin([simplify(b) for b in (exps[0].body + exps[1:])]))
        return ast.Begin([simplify(b) for b in exps])
    if sexp.isApplication():
        return ast.Application(simplify(sexp.proc), [simplify(b) for b in sexp.args])
    if sexp.isSet():
        return ast.Set(simplify(sexp.var), simplify(sexp.exp))
    if sexp.isIf():
        return ast.If(simplify(sexp.cond), simplify(sexp.then), simplify(sexp.elser))
    if sexp.isPrimitive() or sexp.isVariable() or sexp.isDatum():
        return sexp
    raise Exception("Unhandled simplify %s" % sexp)


def all(exp):
    """Wrap `exp` in ``(begin ...)``, then expand, convert and simplify it.

    NOTE: this name shadows the builtin `all` in this module (and for
    ``from <module> import *`` users); it is kept for compatibility.
    """
    program = ast.SExp([parse.parseString("begin"), exp])
    # expand() does not revisit a form after expanding its sub-forms
    # (e.g. nested begin+define); presumably that is why it runs twice.
    return simplify(convert(expand(expand(program))))