module zua.compiler.ir;
import zua.parser.lexer;
import ast = zua.parser.ast;
import std.typecons;
import std.variant;
import std.uuid;

/** Represents a single IR node */
abstract class IRNode {
	/** The range of tokens that represent this IR node */
	Token start;
	Token end; /// ditto
}

/** A statement */
abstract class Stat : IRNode {}

/** An assignment statement */
final class AssignStat : Stat {
	/** A list of lvalues to modify */
	LvalueExpr[] keys;

	/** The values to set each variable to */
	Expr[] values;
}

/** An expression statement */
final class ExprStat : Stat {
	/** The expression to evaluate in this statement */
	Expr expr;
}

/** A block statement */
final class Block : Stat {
	/** A list of statements to be executed by this block */
	Stat[] body;
}

/** A local variable declaration statement */
final class DeclarationStat : Stat {
	/** A list of variables to declare */
	UUID[] keys;

	/** The values to set each variable to */
	Expr[] values;
}

/** A while statement */
final class WhileStat : Stat {
	/** The condition for this while statement */
	Expr cond;

	/** The body of this while statement */
	Block body;
}

/** A repeat statement */
final class RepeatStat : Stat {
	/**
	 * The end condition for this repeat statement.
	 * It is resolved in the scope of the loop body, so it may refer to locals declared inside the body.
	 */
	Expr endCond;

	/** The body of this repeat statement */
	Block body;
}

alias IfEntry = Tuple!(Expr, "cond", Block, "body");

/** An if statement */
final class IfStat : Stat {
	/** The various condition-code pairs in this if statement */
	IfEntry[] entries;

	/** The 'else' body of this if statement */
	Nullable!Block elseBody;
}

/** A numeric for loop */
final class NumericForStat : Stat {
	/** The variable to use in this for loop */
	UUID var;

	/** Defines the range to loop over */
	Expr low;
	Expr high; /// ditto
	Nullable!Expr step; /// ditto

	/** The body of the for loop */
	Block body;
}

/** A foreach loop */
final class ForeachStat : Stat {
	/** The variables to use in this for loop */
	UUID[] vars;

	/** Defines the iterator to loop over */
	Expr[] iter;

	/** The body of the for loop */
	Block body;
}

/** A return statement */
final class ReturnStat : Stat {
	/** The values to return */
	Expr[] values;
}

/** A type of atomic statement */
enum AtomicStatType {
	Break
}

/** An atomic statement */
final class AtomicStat : Stat {
	/** The type of atomic statement this represents */
	AtomicStatType type;
}

/** An expression */
abstract class Expr : IRNode {
}

/** An expression whose value can be set */
abstract class LvalueExpr : Expr {}

/** A global variable expression */
final class GlobalExpr : LvalueExpr {
	/** The name of this global */
	string name;
}

/** A local variable expression */
final class LocalExpr : LvalueExpr {
	/** The unique ID of this local */
	UUID id;
}

/** An upvalue variable expression */
final class UpvalueExpr : LvalueExpr {
	/** The unique ID of this upvalue */
	UUID id;
}

/** An index expression */
final class IndexExpr : LvalueExpr {
	/** The object to index */
	Expr base;

	/** The key to use in indexing the given object */
	Expr key;
}

/** A bracket expression */
final class BracketExpr : Expr {
	/** The expression in the bracket expression */
	Expr expr;
}

/** A function call expression */
final class CallExpr : Expr {
	Expr base; /** The base of the call expression */
	Expr[] args; /** The arguments to use in the call */

	/**
	 * The method name to use in the call.
	 * If supplied, the call is a selfcall, i.e. 'base:method(args)' is called.
	 * Otherwise, 'base' is treated as a function, i.e. 'base(args)' is called.
	 */
	Nullable!string method;
}

/** A type of atomic expression */
enum AtomicExprType {
	Nil,
	False,
	True,
	VariadicTuple
}

/** An atomic expression */
final class AtomicExpr : Expr {
	/** The type of atomic expression this represents */
	AtomicExprType type;
}

/** A number expression */
final class NumberExpr : Expr {
	/** The value of this expression */
	double value;
}

/** A string expression */
final class StringExpr : Expr {
	/** The value of this expression */
	string value;
}

/** A function expression */
final class FunctionExpr : Expr {
	/** A list of argument IDs */
	UUID[] args;

	/** A list of upvalue IDs */
	UUID[] upvalues;

	/** A list of closed IDs */
	UUID[] closed;

	/** The number of variables local to this function */
	ulong localsCount;

	/** Determines if the function is variadic */
	bool variadic;

	/** The body of the function */
	Block body;
}

/** A binary operation */
enum BinaryOperation {
	Add,
	Sub,
	Mul,
	Div,
	Exp,
	Mod,
	Concat,
	CmpLt,
	CmpLe,
	CmpGt,
	CmpGe,
	CmpEq,
	CmpNe,
	And,
	Or
}

/** A binary operation expression */
final class BinaryExpr : Expr {
	/** The token representing this operation */
	Token opToken;

	/** The operation performed on the two sides */
	BinaryOperation op;

	Expr lhs; /** The left-hand side of the expression */
	Expr rhs; /** The right-hand side of the expression */
}

/** A unary operation */
enum UnaryOperation {
	Negate,
	Not,
	Length
}

/** A unary operation expression */
final class UnaryExpr : Expr {
	/** The operation performed on the given expression */
	UnaryOperation op;

	/** The expression to manipulate using the given unary operation */
	Expr expr;
}

/** A single field in a table */
struct TableField {
	/** The key component of this field */
	Expr key;

	/** The value component of this field */
	Expr value;
}

alias FieldEntry = Algebraic!(TableField, Expr);

/** A table constructor expression */
final class TableExpr : Expr {
	/** A list of fields in the table */
	FieldEntry[] fields;
}

/**
 * A lexical scope used to resolve variable names while compiling.
 * Scopes are chained through `parent`; a scope with `isFunction` set marks a function boundary.
 */
private final class Environment {
	/** The enclosing scope, or null if there is none */
	Environment parent;

	/** The IDs of the names declared directly in this scope */
	UUID[string] vars;

	/** The IDs of the outer names that `get` has resolved through this function scope */
	UUID[string] upvalues;

	/** The IDs that `get` has marked in this scope as captured by a nested function */
	bool[UUID] closed;

	/** Whether this scope is the outermost scope of a function */
	bool isFunction;

	this(bool isFunction, Environment parent) {
		this.isFunction = isFunction;
		this.parent = parent;
	}

	/** Checks whether `var` is declared in this scope or in any enclosing scope */
	bool has(string var) {
		return var in vars || (parent && parent.has(var));
	}

	/**
	 * Checks whether resolving `var` from this scope crosses a function boundary,
	 * i.e. whether the name is declared outside the function this scope belongs to.
	 * Returns false for names declared in this scope and for names that cannot be found.
	 */
	bool isUpvalue(string var) {
		if (var in vars) return false;
		if (!parent) return false;

		if (isFunction) {
			return parent.has(var);
		}
		else {
			return parent.isUpvalue(var);
		}
	}

	/**
	 * Resolves `var` to its ID, looking at this scope's own names, then at its recorded upvalues,
	 * and finally at the enclosing scopes.
	 *
	 * When the lookup leaves a function scope, the ID is recorded in that scope's `upvalues` and
	 * marked in the `closed` set of the nearest function scope found starting from `parent`.
	 *
	 * The name must be resolvable (see `has`); this asserts otherwise.
	 */
	UUID get(string var) {
		if (var in vars) {
			return vars[var];
		}
		else if (var in upvalues) {
			return upvalues[var];
		}
		else if (parent) {
			UUID res = parent.get(var);
			if (isFunction) {
				// walk up to the function scope that encloses this function
				Environment at = parent;
				while (!at.isFunction) at = at.parent;
				at.closed[res] = true;
				upvalues[var] = res;
			}
			return res;
		}
		else assert(0);
	}

	/** Declares `var` in this scope under a fresh random ID, replacing any earlier declaration of the same name */
	UUID make(string var) {
		UUID res = randomUUID();
		vars[var] = res;
		return res;
	}
}

/** Translates AST nodes into IR nodes, resolving variable names through a chain of `Environment`s */
private final class ASTCompiler {

	/** The scope that names are currently resolved and declared in */
	Environment env;

	/** Creates an IR node of type `T` that covers the same token range as the AST node `src` */
	private T spanned(T, S)(S src) {
		auto res = new T;
		res.start = src.start;
		res.end = src.end;
		return res;
	}

	/** Compiles any statement by dispatching on its dynamic type; asserts on an unknown statement type */
	Stat compile(ast.Stat stat) {
		if (auto s = cast(ast.AssignStat)stat) return compile(s);
		if (auto s = cast(ast.ExprStat)stat) return compile(s);
		if (auto s = cast(ast.Block)stat) return compile(s);
		if (auto s = cast(ast.DeclarationStat)stat) return compile(s);
		if (auto s = cast(ast.FunctionDeclarationStat)stat) return compile(s);
		if (auto s = cast(ast.WhileStat)stat) return compile(s);
		if (auto s = cast(ast.RepeatStat)stat) return compile(s);
		if (auto s = cast(ast.IfStat)stat) return compile(s);
		if (auto s = cast(ast.NumericForStat)stat) return compile(s);
		if (auto s = cast(ast.ForeachStat)stat) return compile(s);
		if (auto s = cast(ast.ReturnStat)stat) return compile(s);
		if (auto s = cast(ast.AtomicStat)stat) return compile(s);
		assert(0);
	}

	/** Compiles an assignment statement; the keys are compiled before the values */
	AssignStat compile(ast.AssignStat stat) {
		auto res = spanned!AssignStat(stat);
		foreach (k; stat.keys) res.keys ~= compile(k);
		foreach (v; stat.values) res.values ~= compile(v);
		return res;
	}

	/** Compiles an expression statement */
	ExprStat compile(ast.ExprStat stat) {
		auto res = spanned!ExprStat(stat);
		res.expr = compile(stat.expr);
		return res;
	}

	/**
	 * Compiles a block statement.
	 *
	 * If `setupEnv` is set, the block is compiled in a fresh scope nested in the current one;
	 * otherwise its statements are compiled directly in the current scope, which lets the caller
	 * declare names (e.g. loop variables or function arguments) in the same scope as the body.
	 */
	Block compile(ast.Block stat, bool setupEnv = true) {
		auto res = spanned!Block(stat);
		if (setupEnv) env = new Environment(false, env);
		foreach (s; stat.body) {
			res.body ~= compile(s);
		}
		if (setupEnv) env = env.parent;
		return res;
	}

	/** Compiles a local variable declaration statement */
	DeclarationStat compile(ast.DeclarationStat stat) {
		auto res = spanned!DeclarationStat(stat);
		// the values are compiled before the names are declared, so they cannot see the new names
		foreach (e; stat.values) {
			res.values ~= compile(e);
		}
		foreach (v; stat.keys) {
			res.keys ~= env.make(v);
		}
		return res;
	}

	/**
	 * Compiles a function declaration statement into a block holding a declaration followed by an assignment.
	 * The name is declared before the function value is compiled, so the function body can refer to it.
	 */
	Block compile(ast.FunctionDeclarationStat stat) {
		auto res = spanned!Block(stat);
		auto decl = new DeclarationStat;
		decl.start = stat.start;
		decl.end = stat.start; // this is not a typo (the declaration only refers to the `local` keyword)
		if (stat.key != "") {
			decl.keys ~= env.make(stat.key);
		}
		auto assign = spanned!AssignStat(stat);
		if (stat.key != "") {
			auto lvalue = spanned!LocalExpr(stat);
			lvalue.id = env.get(stat.key);
			assign.keys ~= lvalue;
		}
		assign.values ~= compile(stat.value);
		res.body ~= decl;
		res.body ~= assign;
		return res;
	}

	/** Compiles a while statement; the condition is resolved outside the scope of the body */
	WhileStat compile(ast.WhileStat stat) {
		auto res = spanned!WhileStat(stat);
		res.cond = compile(stat.cond);
		res.body = compile(stat.body);
		return res;
	}

	/** Compiles a repeat statement; the end condition is resolved inside the scope of the body */
	RepeatStat compile(ast.RepeatStat stat) {
		auto res = spanned!RepeatStat(stat);
		// In Lua the scope of the loop body only ends after the `until` condition, so the condition
		// must be compiled while the body's scope is still active to see locals declared in the body.
		env = new Environment(false, env);
		res.body = compile(stat.body, false);
		res.endCond = compile(stat.endCond);
		env = env.parent;
		return res;
	}

	/** Compiles an if statement, including its 'else' body if there is one */
	IfStat compile(ast.IfStat stat) {
		auto res = spanned!IfStat(stat);
		foreach (e; stat.entries) {
			res.entries ~= IfEntry(compile(e.cond), compile(e.body));
		}
		if (!stat.elseBody.isNull) {
			res.elseBody = compile(stat.elseBody.get).nullable;
		}
		return res;
	}

	/** Compiles a numeric for loop; the range expressions are resolved before the loop variable is declared */
	NumericForStat compile(ast.NumericForStat stat) {
		auto res = spanned!NumericForStat(stat);
		res.low = compile(stat.low);
		res.high = compile(stat.high);
		if (!stat.step.isNull) {
			res.step = compile(stat.step.get).nullable;
		}
		// the loop variable shares a scope with the statements of the body
		env = new Environment(false, env);
		res.var = env.make(stat.var);
		res.body = compile(stat.body, false);
		env = env.parent;
		return res;
	}

	/** Compiles a foreach loop; the iterator expressions are resolved before the loop variables are declared */
	ForeachStat compile(ast.ForeachStat stat) {
		auto res = spanned!ForeachStat(stat);
		foreach (e; stat.iter) {
			res.iter ~= compile(e);
		}
		// the loop variables share a scope with the statements of the body
		env = new Environment(false, env);
		foreach (v; stat.vars) {
			res.vars ~= env.make(v);
		}
		res.body = compile(stat.body, false);
		env = env.parent;
		return res;
	}

	/** Compiles a return statement */
	ReturnStat compile(ast.ReturnStat stat) {
		auto res = spanned!ReturnStat(stat);
		foreach (v; stat.values) {
			res.values ~= compile(v);
		}
		return res;
	}

	/** Compiles an atomic statement; asserts on an unknown atomic statement type */
	AtomicStat compile(ast.AtomicStat stat) {
		auto res = spanned!AtomicStat(stat);
		switch (stat.type) {
		case ast.AtomicStatType.Break:
			res.type = AtomicStatType.Break;
			break;
		default: assert(0);
		}
		return res;
	}

	/** Compiles any expression by dispatching on its dynamic type; asserts on an unknown expression type */
	Expr compile(ast.Expr expr) {
		if (auto e = cast(ast.PrefixExpr)expr) return compile(e);
		if (auto e = cast(ast.AtomicExpr)expr) return compile(e);
		if (auto e = cast(ast.NumberExpr)expr) return compile(e);
		if (auto e = cast(ast.StringExpr)expr) return compile(e);
		if (auto e = cast(ast.FunctionExpr)expr) return compile(e);
		if (auto e = cast(ast.BinaryExpr)expr) return compile(e);
		if (auto e = cast(ast.UnaryExpr)expr) return compile(e);
		if (auto e = cast(ast.TableExpr)expr) return compile(e);
		assert(0);
	}

	/** Compiles any prefix expression by dispatching on its dynamic type; asserts on an unknown type */
	Expr compile(ast.PrefixExpr expr) {
		if (auto e = cast(ast.LvalueExpr)expr) return compile(e);
		if (auto e = cast(ast.BracketExpr)expr) return compile(e);
		if (auto e = cast(ast.CallExpr)expr) return compile(e);
		assert(0);
	}

	/** Compiles any lvalue expression by dispatching on its dynamic type; asserts on an unknown type */
	LvalueExpr compile(ast.LvalueExpr expr) {
		if (auto e = cast(ast.VariableExpr)expr) return compile(e);
		if (auto e = cast(ast.IndexExpr)expr) return compile(e);
		assert(0);
	}

	/**
	 * Compiles a variable reference.
	 * A name known to the current scope chain becomes an upvalue or a local, depending on whether
	 * it is declared outside the current function; any other name becomes a global.
	 */
	LvalueExpr compile(ast.VariableExpr expr) {
		if (env.has(expr.name)) {
			UUID id = env.get(expr.name);
			if (env.isUpvalue(expr.name)) {
				auto res = spanned!UpvalueExpr(expr);
				res.id = id;
				return res;
			}
			else {
				auto res = spanned!LocalExpr(expr);
				res.id = id;
				return res;
			}
		}
		else {
			auto res = spanned!GlobalExpr(expr);
			res.name = expr.name;
			return res;
		}
	}

	/** Compiles an index expression */
	IndexExpr compile(ast.IndexExpr expr) {
		auto res = spanned!IndexExpr(expr);
		res.base = compile(expr.base);
		res.key = compile(expr.key);
		return res;
	}

	/** Compiles a bracket expression */
	BracketExpr compile(ast.BracketExpr expr) {
		auto res = spanned!BracketExpr(expr);
		res.expr = compile(expr.expr);
		return res;
	}

	/** Compiles a function call expression; the base is compiled before the arguments */
	CallExpr compile(ast.CallExpr expr) {
		auto res = spanned!CallExpr(expr);
		res.base = compile(expr.base);
		res.method = expr.method;
		foreach (arg; expr.args) {
			res.args ~= compile(arg);
		}
		return res;
	}

	/** Compiles an atomic expression; asserts on an unknown atomic expression type */
	AtomicExpr compile(ast.AtomicExpr expr) {
		auto res = spanned!AtomicExpr(expr);
		switch (expr.type) {
		case ast.AtomicExprType.Nil:
			res.type = AtomicExprType.Nil;
			break;
		case ast.AtomicExprType.False:
			res.type = AtomicExprType.False;
			break;
		case ast.AtomicExprType.True:
			res.type = AtomicExprType.True;
			break;
		case ast.AtomicExprType.VariadicTuple:
			res.type = AtomicExprType.VariadicTuple;
			break;
		default: assert(0);
		}
		return res;
	}

	/** Compiles a number expression */
	NumberExpr compile(ast.NumberExpr expr) {
		auto res = spanned!NumberExpr(expr);
		res.value = expr.value;
		return res;
	}

	/** Compiles a string expression */
	StringExpr compile(ast.StringExpr expr) {
		auto res = spanned!StringExpr(expr);
		res.value = expr.value;
		return res;
	}

	/**
	 * Compiles a function expression.
	 * The arguments and the statements of the body share a single function scope; the upvalues and
	 * closed IDs recorded in that scope while compiling the body are copied into the result.
	 */
	FunctionExpr compile(ast.FunctionExpr expr) {
		auto res = spanned!FunctionExpr(expr);
		res.variadic = expr.variadic;
		env = new Environment(true, env);
		foreach (v; expr.args) {
			res.args ~= env.make(v);
		}
		res.body = compile(expr.body, false);
		// NOTE(review): this only counts the names declared directly in the function's own scope
		// (arguments included); names declared in nested block scopes are not counted, and a
		// redeclared name is counted once - confirm that consumers of localsCount expect this
		res.localsCount = env.vars.length;
		foreach (u; env.upvalues.byValue) res.upvalues ~= u;
		foreach (u; env.closed.byKey) res.closed ~= u;
		env = env.parent;
		return res;
	}

	/** Compiles a toplevel block into a variadic function expression without arguments */
	FunctionExpr compileToplevel(ast.Block stat) {
		auto res = spanned!FunctionExpr(stat);
		res.variadic = true;
		env = new Environment(true, env);
		res.body = compile(stat, false);
		// NOTE(review): this only counts the names declared directly in the toplevel scope;
		// names declared in nested block scopes are not counted - confirm that this is intended
		res.localsCount = env.vars.length;
		foreach (u; env.upvalues.byValue) res.upvalues ~= u;
		foreach (u; env.closed.byKey) res.closed ~= u;
		env = env.parent;
		return res;
	}

	/** Compiles a binary operation expression; the left-hand side is compiled first */
	BinaryExpr compile(ast.BinaryExpr expr) {
		auto res = spanned!BinaryExpr(expr);
		res.opToken = expr.opToken;
		res.op = expr.op;
		res.lhs = compile(expr.lhs);
		res.rhs = compile(expr.rhs);
		return res;
	}

	/** Compiles a unary operation expression */
	UnaryExpr compile(ast.UnaryExpr expr) {
		auto res = spanned!UnaryExpr(expr);
		res.op = expr.op;
		res.expr = compile(expr.expr);
		return res;
	}

	/** Compiles a table constructor expression, keeping the fields in their original order */
	TableExpr compile(ast.TableExpr expr) {
		auto res = spanned!TableExpr(expr);
		foreach (field; expr.fields) {
			if (auto f = field.peek!(ast.TableField)) {
				TableField compiled = {
					key: compile(f.key),
					value: compile(f.value)
				};
				res.fields ~= FieldEntry(compiled);
			}
			else {
				res.fields ~= FieldEntry(compile(field.get!(ast.Expr)));
			}
		}
		return res;
	}

}

/** Compile an AST block into an IR block */
FunctionExpr compileAST(ast.Block block) {
	return new ASTCompiler().compileToplevel(block);
}

unittest {
	import zua.parser.parser : Parser;
	import zua.diagnostic : Diagnostic;

	Diagnostic[] d;
	auto lexer = new Lexer(q"(
		local i = 2
		print(i)
	)", d);
	auto parser = new Parser(lexer, d);
	auto tree = parser.toplevel();

	assert(tree.body.length == 2);

	const ir = compileAST(tree);

	assert(ir.closed.length == 0);
	assert(ir.upvalues.length == 0);

	auto bd = ir.body;

	const decl = bd.body[0];
	const print = bd.body[1];

	const decl2 = cast(DeclarationStat)decl;
	assert(decl2);

	const print2 = cast(ExprStat)print;
	assert(print2);

	const print3 = cast(CallExpr)print2.expr;
	assert(print3);

	assert(cast(GlobalExpr)print3.base);
	assert(print3.args.length == 1);

	const local = cast(LocalExpr)print3.args[0];
	assert(local);

	assert(decl2.keys.length == 1);
	assert(decl2.keys[0] == local.id);

	assert(d == []);
}

unittest {
	import zua.parser.parser : Parser;
	import zua.diagnostic : Diagnostic;

	Diagnostic[] dg;
	auto lexer = new Lexer(q"(
		for i in i do
			(function()
				return i
				local i
				return i
			end)()
			return i
		end
	)", dg);
	auto parser = new Parser(lexer, dg);
	auto tree = parser.toplevel();

	assert(tree.body.length == 1);

	const ir = compileAST(tree);

	assert(ir.closed.length == 1);
	assert(ir.upvalues.length == 0);

	auto bd = ir.body;

	assert(bd.body.length == 1);

	auto a = cast(ForeachStat)bd.body[0];
	assert(a);
	assert(a.iter.length == 1);
	assert(a.vars.length == 1);
	assert(a.body.body.length == 2);
	assert(a.vars[0] == ir.closed[0]);

	auto b = cast(GlobalExpr)a.iter[0];
	assert(b);
	assert(b.name == "i");

	auto c = cast(ExprStat)a.body.body[0];
	assert(c);

	auto d = cast(CallExpr)c.expr;
	assert(d);
	assert(d.args == []);

	auto de = cast(BracketExpr)d.base;
	assert(de);

	auto e = cast(FunctionExpr)de.expr;
	assert(e);
	assert(e.upvalues.length == 1);
	assert(e.upvalues[0] == a.vars[0]);
	assert(e.args == []);
	assert(!e.variadic);
	assert(e.body.body.length == 3);

	auto f = cast(ReturnStat)e.body.body[0];
	assert(f);
	assert(f.values.length == 1);

	auto g = cast(UpvalueExpr)f.values[0];
	assert(g);
	assert(g.id == a.vars[0]);

	auto h = cast(ReturnStat)a.body.body[1];
	assert(h);
	assert(h.values.length == 1);

	auto i = cast(LocalExpr)h.values[0];
	assert(i);
	assert(i.id == a.vars[0]);

	auto j = cast(DeclarationStat)e.body.body[1];
	assert(j);
	assert(j.keys.length == 1);
	assert(j.values.length == 0);

	auto k = cast(ReturnStat)e.body.body[2];
	assert(k);
	assert(k.values.length == 1);

	auto l = cast(LocalExpr)k.values[0];
	assert(l);
	assert(l.id == j.keys[0]);
	assert(l.id != a.vars[0]);

	assert(dg == []);
}

// the `until` condition of a repeat statement can see locals declared in the loop body
unittest {
	import zua.parser.parser : Parser;
	import zua.diagnostic : Diagnostic;

	Diagnostic[] dg;
	auto lexer = new Lexer(q"(
		repeat
			local done = true
		until done
	)", dg);
	auto parser = new Parser(lexer, dg);
	auto tree = parser.toplevel();

	assert(tree.body.length == 1);

	auto ir = compileAST(tree);

	assert(ir.closed.length == 0);
	assert(ir.upvalues.length == 0);
	assert(ir.body.body.length == 1);

	auto loop = cast(RepeatStat)ir.body.body[0];
	assert(loop);
	assert(loop.body.body.length == 1);

	auto decl = cast(DeclarationStat)loop.body.body[0];
	assert(decl);
	assert(decl.keys.length == 1);

	auto cond = cast(LocalExpr)loop.endCond;
	assert(cond);
	assert(cond.id == decl.keys[0]);

	assert(dg == []);
}