1 module zua.compiler.ir;
2 import zua.parser.lexer;
3 import ast = zua.parser.ast;
4 import std.typecons;
5 import std.variant;
6 import std.uuid;
/// Common base class of every IR node.
abstract class IRNode {
	/// First token of the source range this node was produced from.
	Token start;

	/// Last token of that source range.
	Token end;
}
/// Base class of all statement nodes.
abstract class Stat : IRNode {
}
/// Assignment of a list of values to a list of lvalues.
final class AssignStat : Stat {
	/// The assignment targets.
	LvalueExpr[] keys;

	/// The expressions whose values are assigned to the targets.
	Expr[] values;
}
/// A statement that consists of a single expression.
final class ExprStat : Stat {
	/// The expression evaluated by this statement.
	Expr expr;
}
/// A sequence of statements.
final class Block : Stat {
	/// The statements of this block, in execution order.
	Stat[] body;
}
/// Declaration of one or more local variables.
final class DeclarationStat : Stat {
	/// The unique IDs of the variables being declared.
	UUID[] keys;

	/// The initialiser expressions for the declared variables.
	Expr[] values;
}
/// A `while` loop.
final class WhileStat : Stat {
	/// The loop condition.
	Expr cond;

	/// The loop body.
	Block body;
}
/// A `repeat ... until` loop.
final class RepeatStat : Stat {
	/// The condition that ends the loop.
	Expr endCond;

	/// The loop body.
	Block body;
}
/// One arm of an `if` statement: a condition (`cond`) and the block (`body`) guarded by it.
alias IfEntry = Tuple!(Expr, "cond", Block, "body");

/// An `if` statement.
final class IfStat : Stat {
	/// The condition/body pairs of this statement.
	IfEntry[] entries;

	/// The `else` block; left null when the statement has no `else` part.
	Nullable!Block elseBody;
}
/// A numeric `for` loop.
final class NumericForStat : Stat {
	/// The unique ID of the loop variable.
	UUID var;

	/// Lower bound of the range to loop over.
	Expr low;

	/// Upper bound of the range to loop over.
	Expr high;

	/// Step of the range; left null when the loop gives no explicit step.
	Nullable!Expr step;

	/// The loop body.
	Block body;
}
/// A generic (`for ... in ...`) loop.
final class ForeachStat : Stat {
	/// The unique IDs of the loop variables.
	UUID[] vars;

	/// The expressions that produce the iterator to loop over.
	Expr[] iter;

	/// The loop body.
	Block body;
}
/// A `return` statement.
final class ReturnStat : Stat {
	/// The expressions whose values are returned.
	Expr[] values;
}
/// The kinds of atomic (operand-less) statement.
enum AtomicStatType {
	/// A `break` statement.
	Break,
}
/// A statement without operands, fully described by its kind.
final class AtomicStat : Stat {
	/// Which atomic statement this node represents.
	AtomicStatType type;
}
/// Base class of all expression nodes.
abstract class Expr : IRNode {
}
/// Base class of expressions that can be assigned to.
abstract class LvalueExpr : Expr {
}
/// A reference to a global variable.
final class GlobalExpr : LvalueExpr {
	/// The global's name.
	string name;
}
/// A reference to a local variable.
final class LocalExpr : LvalueExpr {
	/// The unique ID of the local being referenced.
	UUID id;
}
/// A reference to an upvalue.
final class UpvalueExpr : LvalueExpr {
	/// The unique ID of the upvalue being referenced.
	UUID id;
}
/// An indexing expression.
final class IndexExpr : LvalueExpr {
	/// The object being indexed.
	Expr base;

	/// The key used to index the object.
	Expr key;
}
/// A bracketed expression.
final class BracketExpr : Expr {
	/// The expression inside the brackets.
	Expr expr;
}
/// A function call.
final class CallExpr : Expr {
	/// The expression the call is made on.
	Expr base;

	/// The call arguments.
	Expr[] args;

	/**
	 * Method name for a self-call.
	 * When set, the node represents 'base:method(args)'.
	 * When null, 'base' itself is the function and the node represents 'base(args)'.
	 */
	Nullable!string method;
}
/// The kinds of atomic (operand-less) expression.
enum AtomicExprType {
	Nil,
	False,
	True,
	VariadicTuple,
}
/// An expression without operands, fully described by its kind.
final class AtomicExpr : Expr {
	/// Which atomic expression this node represents.
	AtomicExprType type;
}
/// A numeric literal.
final class NumberExpr : Expr {
	/// The literal's value.
	double value;
}
/// A string literal.
final class StringExpr : Expr {
	/// The literal's value.
	string value;
}
/// A function expression.
final class FunctionExpr : Expr {
	/// The unique IDs of the function's arguments.
	UUID[] args;

	/// The unique IDs of the upvalues this function uses.
	UUID[] upvalues;

	/// The unique IDs recorded as closed over by functions nested inside this one.
	UUID[] closed;

	/// The number of variables local to this function.
	ulong localsCount;

	/// Whether the function is variadic.
	bool variadic;

	/// The function body.
	Block body;
}
/// The operations a BinaryExpr can perform.
enum BinaryOperation {
	Add,
	Sub,
	Mul,
	Div,
	Exp,
	Mod,
	Concat,
	CmpLt,
	CmpLe,
	CmpGt,
	CmpGe,
	CmpEq,
	CmpNe,
	And,
	Or,
}
/// An expression applying a binary operation to two operands.
final class BinaryExpr : Expr {
	/// The token of the operator itself.
	Token opToken;

	/// The operation applied to the two operands.
	BinaryOperation op;

	/// The left-hand operand.
	Expr lhs;

	/// The right-hand operand.
	Expr rhs;
}
/// The operations a UnaryExpr can perform.
enum UnaryOperation {
	Negate,
	Not,
	Length,
}
/// An expression applying a unary operation to one operand.
final class UnaryExpr : Expr {
	/// The operation applied to the operand.
	UnaryOperation op;

	/// The operand.
	Expr expr;
}
/// A keyed field of a table constructor.
struct TableField {
	/// The expression producing the field's key.
	Expr key;

	/// The expression producing the field's value.
	Expr value;
}
/// One entry of a table constructor: either a keyed TableField or a bare value expression.
alias FieldEntry = Algebraic!(TableField, Expr);

/// A table constructor.
final class TableExpr : Expr {
	/// The entries of the table constructor.
	FieldEntry[] fields;
}
/**
 * One lexical scope used while resolving variable names to unique IDs.
 *
 * Scopes form a chain through `parent`. A scope is either a plain block scope
 * or a function scope (`isFunction`); crossing a function scope while looking
 * a name up is what turns a variable into an upvalue.
 */
private final class Environment {
	/// The enclosing scope, or null for the outermost one.
	Environment parent;

	/// Variables declared directly in this scope, by name.
	UUID[string] vars;

	/// Names this (function) scope has resolved through its parent, with their IDs.
	UUID[string] upvalues;

	/// IDs that a function scope nested inside this (function) scope resolved through it.
	bool[UUID] closed;

	/// Whether this scope is the outermost scope of a function.
	bool isFunction;

	this(bool isFunction, Environment parent) {
		this.parent = parent;
		this.isFunction = isFunction;
	}

	/// Returns true if `var` is declared in this scope or in any enclosing scope.
	bool has(string var) {
		if ((var in vars) !is null) return true;
		return parent !is null && parent.has(var);
	}

	/**
	 * Returns true if `var`, seen from this scope, lives outside the nearest
	 * function scope: it is not declared between here and that function scope
	 * (inclusive) but is declared somewhere beyond it.
	 */
	bool isUpvalue(string var) {
		if ((var in vars) !is null || parent is null) return false;
		return isFunction ? parent.has(var) : parent.isUpvalue(var);
	}

	/**
	 * Returns the ID `var` resolves to from this scope.
	 *
	 * When the lookup has to leave a function scope, that scope remembers the
	 * name in `upvalues` and the nearest enclosing function scope records the
	 * ID in `closed`. The name must be resolvable (see `has`); otherwise the
	 * lookup ends in `assert(0)`.
	 */
	UUID get(string var) {
		if (auto own = var in vars) return *own;
		if (auto captured = var in upvalues) return *captured;
		if (parent is null) assert(0);

		UUID id = parent.get(var);
		if (isFunction) {
			// find the function scope that directly encloses this function
			Environment owner = parent;
			while (!owner.isFunction) owner = owner.parent;
			owner.closed[id] = true;
			upvalues[var] = id;
		}
		return id;
	}

	/// Declares `var` in this scope under a fresh random ID, replacing any previous declaration of the same name here.
	UUID make(string var) {
		UUID fresh = randomUUID();
		vars[var] = fresh;
		return fresh;
	}
}
/**
 * Translates the parser's AST into IR.
 *
 * Every `compile` overload converts one kind of AST node into the matching IR
 * node and copies the node's token range. Variable names are resolved against
 * a chain of Environment scopes: names found in a scope become LocalExpr or
 * UpvalueExpr nodes carrying the variable's unique ID, all other names become
 * GlobalExpr nodes.
 */
private final class ASTCompiler {

	/** The innermost scope at the current point of compilation; null while nothing is being compiled */
	Environment env;

	/** Allocates an IR node of type T that covers the same token range as `src` */
	private static T spanned(T, N)(N src) {
		auto res = new T;
		res.start = src.start;
		res.end = src.end;
		return res;
	}

	/** Opens a new innermost scope */
	private void pushScope(bool isFunction) {
		env = new Environment(isFunction, env);
	}

	/** Closes the innermost scope */
	private void popScope() {
		env = env.parent;
	}

	/**
	 * Copies the bookkeeping gathered in the current function scope into `res`
	 * and closes that scope.
	 *
	 * NOTE(review): localsCount only counts names declared directly in the
	 * function's own scope (arguments and top-level locals); locals of nested
	 * blocks live in their own Environment. Confirm this is what consumers expect.
	 */
	private void finishFunction(FunctionExpr res) {
		res.localsCount = env.vars.length;
		res.upvalues ~= env.upvalues.values;
		res.closed ~= env.closed.keys;
		popScope();
	}

	/** Compiles any statement by dispatching on its dynamic type */
	Stat compile(ast.Stat stat) {
		if (auto s = cast(ast.AssignStat)stat) return compile(s);
		if (auto s = cast(ast.ExprStat)stat) return compile(s);
		if (auto s = cast(ast.Block)stat) return compile(s);
		if (auto s = cast(ast.DeclarationStat)stat) return compile(s);
		if (auto s = cast(ast.FunctionDeclarationStat)stat) return compile(s);
		if (auto s = cast(ast.WhileStat)stat) return compile(s);
		if (auto s = cast(ast.RepeatStat)stat) return compile(s);
		if (auto s = cast(ast.IfStat)stat) return compile(s);
		if (auto s = cast(ast.NumericForStat)stat) return compile(s);
		if (auto s = cast(ast.ForeachStat)stat) return compile(s);
		if (auto s = cast(ast.ReturnStat)stat) return compile(s);
		if (auto s = cast(ast.AtomicStat)stat) return compile(s);
		assert(0);
	}

	/** Compiles an assignment; targets are compiled before values */
	AssignStat compile(ast.AssignStat stat) {
		auto res = spanned!AssignStat(stat);
		foreach (k; stat.keys) res.keys ~= compile(k);
		foreach (v; stat.values) res.values ~= compile(v);
		return res;
	}

	/** Compiles an expression statement */
	ExprStat compile(ast.ExprStat stat) {
		auto res = spanned!ExprStat(stat);
		res.expr = compile(stat.expr);
		return res;
	}

	/**
	 * Compiles a block.
	 * Params:
	 *   stat = the block to compile
	 *   setupEnv = when true the block gets its own scope; pass false when the
	 *              caller has already opened the scope the block should use
	 */
	Block compile(ast.Block stat, bool setupEnv = true) {
		auto res = spanned!Block(stat);
		if (setupEnv) pushScope(false);
		foreach (s; stat.body) {
			res.body ~= compile(s);
		}
		if (setupEnv) popScope();
		return res;
	}

	/** Compiles a local declaration */
	DeclarationStat compile(ast.DeclarationStat stat) {
		auto res = spanned!DeclarationStat(stat);
		// the initialisers are resolved first, so they cannot see the names being declared
		foreach (e; stat.values) {
			res.values ~= compile(e);
		}
		foreach (v; stat.keys) {
			res.keys ~= env.make(v);
		}
		return res;
	}

	/**
	 * Compiles a function declaration statement into a block holding a
	 * declaration followed by an assignment. The name is declared before the
	 * function value is compiled, so the function body can refer to itself.
	 */
	Block compile(ast.FunctionDeclarationStat stat) {
		auto res = spanned!Block(stat);

		auto decl = spanned!DeclarationStat(stat);
		decl.end = stat.start; // this is not a typo (the declaration only refers to the `local` keyword)

		auto assign = spanned!AssignStat(stat);

		if (stat.key != "") {
			decl.keys ~= env.make(stat.key);

			auto lvalue = spanned!LocalExpr(stat);
			lvalue.id = env.get(stat.key);
			assign.keys ~= lvalue;
		}
		assign.values ~= compile(stat.value);

		res.body ~= decl;
		res.body ~= assign;
		return res;
	}

	/** Compiles a while loop */
	WhileStat compile(ast.WhileStat stat) {
		auto res = spanned!WhileStat(stat);
		res.cond = compile(stat.cond);
		res.body = compile(stat.body);
		return res;
	}

	/**
	 * Compiles a repeat loop.
	 *
	 * In Lua the scope of the loop body extends over the `until` condition, so
	 * the condition may refer to locals declared inside the body. The body is
	 * therefore compiled first, and the condition is resolved inside the
	 * body's scope before that scope is closed.
	 */
	RepeatStat compile(ast.RepeatStat stat) {
		auto res = spanned!RepeatStat(stat);
		pushScope(false);
		res.body = compile(stat.body, false);
		res.endCond = compile(stat.endCond);
		popScope();
		return res;
	}

	/** Compiles an if statement, including its optional else body */
	IfStat compile(ast.IfStat stat) {
		auto res = spanned!IfStat(stat);
		foreach (e; stat.entries) {
			res.entries ~= IfEntry(compile(e.cond), compile(e.body));
		}
		if (!stat.elseBody.isNull) {
			res.elseBody = compile(stat.elseBody.get).nullable;
		}
		return res;
	}

	/** Compiles a numeric for loop */
	NumericForStat compile(ast.NumericForStat stat) {
		auto res = spanned!NumericForStat(stat);
		// the bounds are resolved before the loop variable exists
		res.low = compile(stat.low);
		res.high = compile(stat.high);
		if (!stat.step.isNull) {
			res.step = compile(stat.step.get).nullable;
		}
		// the loop variable and the body share a single scope
		pushScope(false);
		res.var = env.make(stat.var);
		res.body = compile(stat.body, false);
		popScope();
		return res;
	}

	/** Compiles a foreach loop */
	ForeachStat compile(ast.ForeachStat stat) {
		auto res = spanned!ForeachStat(stat);
		// the iterator expressions are resolved before the loop variables exist
		foreach (e; stat.iter) {
			res.iter ~= compile(e);
		}
		// the loop variables and the body share a single scope
		pushScope(false);
		foreach (v; stat.vars) {
			res.vars ~= env.make(v);
		}
		res.body = compile(stat.body, false);
		popScope();
		return res;
	}

	/** Compiles a return statement */
	ReturnStat compile(ast.ReturnStat stat) {
		auto res = spanned!ReturnStat(stat);
		foreach (v; stat.values) {
			res.values ~= compile(v);
		}
		return res;
	}

	/** Compiles an atomic statement */
	AtomicStat compile(ast.AtomicStat stat) {
		auto res = spanned!AtomicStat(stat);
		switch (stat.type) {
		case ast.AtomicStatType.Break:
			res.type = AtomicStatType.Break;
			break;
		default: assert(0);
		}
		return res;
	}

	/** Compiles any expression by dispatching on its dynamic type */
	Expr compile(ast.Expr expr) {
		if (auto e = cast(ast.PrefixExpr)expr) return compile(e);
		if (auto e = cast(ast.AtomicExpr)expr) return compile(e);
		if (auto e = cast(ast.NumberExpr)expr) return compile(e);
		if (auto e = cast(ast.StringExpr)expr) return compile(e);
		if (auto e = cast(ast.FunctionExpr)expr) return compile(e);
		if (auto e = cast(ast.BinaryExpr)expr) return compile(e);
		if (auto e = cast(ast.UnaryExpr)expr) return compile(e);
		if (auto e = cast(ast.TableExpr)expr) return compile(e);
		assert(0);
	}

	/** Compiles any prefix expression by dispatching on its dynamic type */
	Expr compile(ast.PrefixExpr expr) {
		if (auto e = cast(ast.LvalueExpr)expr) return compile(e);
		if (auto e = cast(ast.BracketExpr)expr) return compile(e);
		if (auto e = cast(ast.CallExpr)expr) return compile(e);
		assert(0);
	}

	/** Compiles any lvalue expression by dispatching on its dynamic type */
	LvalueExpr compile(ast.LvalueExpr expr) {
		if (auto e = cast(ast.VariableExpr)expr) return compile(e);
		if (auto e = cast(ast.IndexExpr)expr) return compile(e);
		assert(0);
	}

	/**
	 * Resolves a variable name: names unknown to every enclosing scope become
	 * globals, names reached by leaving the current function become upvalues,
	 * and everything else becomes a local.
	 */
	LvalueExpr compile(ast.VariableExpr expr) {
		if (!env.has(expr.name)) {
			auto global = spanned!GlobalExpr(expr);
			global.name = expr.name;
			return global;
		}

		// get() also records the capture in the scopes involved, so it runs for every known name
		UUID id = env.get(expr.name);
		if (env.isUpvalue(expr.name)) {
			auto upvalue = spanned!UpvalueExpr(expr);
			upvalue.id = id;
			return upvalue;
		}

		auto local = spanned!LocalExpr(expr);
		local.id = id;
		return local;
	}

	/** Compiles an index expression */
	IndexExpr compile(ast.IndexExpr expr) {
		auto res = spanned!IndexExpr(expr);
		res.base = compile(expr.base);
		res.key = compile(expr.key);
		return res;
	}

	/** Compiles a bracket expression */
	BracketExpr compile(ast.BracketExpr expr) {
		auto res = spanned!BracketExpr(expr);
		res.expr = compile(expr.expr);
		return res;
	}

	/** Compiles a call expression; the base is compiled before the arguments */
	CallExpr compile(ast.CallExpr expr) {
		auto res = spanned!CallExpr(expr);
		res.base = compile(expr.base);
		res.method = expr.method;
		foreach (arg; expr.args) {
			res.args ~= compile(arg);
		}
		return res;
	}

	/** Compiles an atomic expression */
	AtomicExpr compile(ast.AtomicExpr expr) {
		auto res = spanned!AtomicExpr(expr);
		switch (expr.type) {
		case ast.AtomicExprType.Nil:
			res.type = AtomicExprType.Nil;
			break;
		case ast.AtomicExprType.False:
			res.type = AtomicExprType.False;
			break;
		case ast.AtomicExprType.True:
			res.type = AtomicExprType.True;
			break;
		case ast.AtomicExprType.VariadicTuple:
			res.type = AtomicExprType.VariadicTuple;
			break;
		default: assert(0);
		}
		return res;
	}

	/** Compiles a number literal */
	NumberExpr compile(ast.NumberExpr expr) {
		auto res = spanned!NumberExpr(expr);
		res.value = expr.value;
		return res;
	}

	/** Compiles a string literal */
	StringExpr compile(ast.StringExpr expr) {
		auto res = spanned!StringExpr(expr);
		res.value = expr.value;
		return res;
	}

	/** Compiles a function expression; the arguments and the body share one function scope */
	FunctionExpr compile(ast.FunctionExpr expr) {
		auto res = spanned!FunctionExpr(expr);
		res.variadic = expr.variadic;
		pushScope(true);
		foreach (v; expr.args) {
			res.args ~= env.make(v);
		}
		res.body = compile(expr.body, false);
		finishFunction(res);
		return res;
	}

	/** Compiles a whole chunk as a variadic function without arguments */
	FunctionExpr compileToplevel(ast.Block stat) {
		auto res = spanned!FunctionExpr(stat);
		res.variadic = true;
		pushScope(true);
		res.body = compile(stat, false);
		finishFunction(res);
		return res;
	}

	/** Compiles a binary expression; the left operand is compiled first */
	BinaryExpr compile(ast.BinaryExpr expr) {
		auto res = spanned!BinaryExpr(expr);
		res.opToken = expr.opToken;
		res.op = expr.op;
		res.lhs = compile(expr.lhs);
		res.rhs = compile(expr.rhs);
		return res;
	}

	/** Compiles a unary expression */
	UnaryExpr compile(ast.UnaryExpr expr) {
		auto res = spanned!UnaryExpr(expr);
		res.op = expr.op;
		res.expr = compile(expr.expr);
		return res;
	}

	/** Compiles a table constructor, keeping keyed fields and bare values in their original order */
	TableExpr compile(ast.TableExpr expr) {
		auto res = spanned!TableExpr(expr);
		foreach (field; expr.fields) {
			if (auto f = field.peek!(ast.TableField)) {
				TableField compiled = {
					key: compile(f.key),
					value: compile(f.value)
				};
				res.fields ~= FieldEntry(compiled);
			}
			else {
				res.fields ~= FieldEntry(compile(field.get!(ast.Expr)));
			}
		}
		return res;
	}

}

// The `until` condition of a repeat loop must see locals declared in the loop body.
unittest {
	import zua.parser.parser : Parser;
	import zua.diagnostic : Diagnostic;

	Diagnostic[] dg;
	auto lexer = new Lexer(q"(
		repeat
			local done = true
		until done
	)", dg);
	auto parser = new Parser(lexer, dg);
	auto ir = compileAST(parser.toplevel());

	auto loop = cast(RepeatStat)ir.body.body[0];
	assert(loop);
	assert(loop.body.body.length == 1);

	auto decl = cast(DeclarationStat)loop.body.body[0];
	assert(decl);
	assert(decl.keys.length == 1);

	auto cond = cast(LocalExpr)loop.endCond;
	assert(cond);
	assert(cond.id == decl.keys[0]);

	assert(dg == []);
}
/**
 * Compiles a top-level AST block into IR.
 *
 * Returns: the whole block wrapped in a variadic FunctionExpr, as produced by
 * ASTCompiler.compileToplevel on a fresh compiler instance.
 */
FunctionExpr compileAST(ast.Block block) {
	auto compiler = new ASTCompiler;
	return compiler.compileToplevel(block);
}
// A top-level local is declared once and later references resolve to the same ID,
// while an undeclared name (`print`) resolves to a global.
unittest {
	import zua.diagnostic : Diagnostic;
	import zua.parser.parser : Parser;

	Diagnostic[] diagnostics;
	auto lexer = new Lexer(q"(
		local i = 2
		print(i)
	)", diagnostics);
	auto parser = new Parser(lexer, diagnostics);
	auto tree = parser.toplevel();
	assert(tree.body.length == 2);

	const ir = compileAST(tree);

	// nothing is captured at the top level
	assert(ir.closed.length == 0);
	assert(ir.upvalues.length == 0);

	auto topBlock = ir.body;
	const firstStat = topBlock.body[0];
	const secondStat = topBlock.body[1];

	// `local i = 2`
	const declaration = cast(DeclarationStat)firstStat;
	assert(declaration);

	// `print(i)`
	const callStat = cast(ExprStat)secondStat;
	assert(callStat);

	const call = cast(CallExpr)callStat.expr;
	assert(call);
	assert(cast(GlobalExpr)call.base);
	assert(call.args.length == 1);

	const argument = cast(LocalExpr)call.args[0];
	assert(argument);

	// the argument refers to the variable introduced by the declaration
	assert(declaration.keys.length == 1);
	assert(declaration.keys[0] == argument.id);

	assert(diagnostics == []);
}
// Exercises scoping around a loop variable: the iterator expression is resolved
// before the loop variable exists, a nested function captures the loop variable
// as an upvalue, and a later `local` of the same name shadows it.
unittest {
	import zua.diagnostic : Diagnostic;
	import zua.parser.parser : Parser;

	Diagnostic[] diagnostics;
	auto lexer = new Lexer(q"(
		for i in i do
			(function()
				return i
				local i
				return i
			end)()
			return i
		end
	)", diagnostics);
	auto parser = new Parser(lexer, diagnostics);
	auto tree = parser.toplevel();
	assert(tree.body.length == 1);

	const ir = compileAST(tree);

	// the toplevel owns exactly one captured variable and captures nothing itself
	assert(ir.closed.length == 1);
	assert(ir.upvalues.length == 0);

	auto topBlock = ir.body;
	assert(topBlock.body.length == 1);

	// `for i in i do ... end`
	auto loop = cast(ForeachStat)topBlock.body[0];
	assert(loop);
	assert(loop.iter.length == 1);
	assert(loop.vars.length == 1);
	assert(loop.body.body.length == 2);
	assert(loop.vars[0] == ir.closed[0]);

	// the `i` being iterated over is resolved outside the loop scope, hence a global
	auto iterated = cast(GlobalExpr)loop.iter[0];
	assert(iterated);
	assert(iterated.name == "i");

	// `(function() ... end)()`
	auto callStat = cast(ExprStat)loop.body.body[0];
	assert(callStat);

	auto call = cast(CallExpr)callStat.expr;
	assert(call);
	assert(call.args == []);

	auto bracket = cast(BracketExpr)call.base;
	assert(bracket);

	auto fn = cast(FunctionExpr)bracket.expr;
	assert(fn);
	assert(fn.upvalues.length == 1);
	assert(fn.upvalues[0] == loop.vars[0]);
	assert(fn.args == []);
	assert(!fn.variadic);
	assert(fn.body.body.length == 3);

	// first `return i` inside the function: the loop variable, seen as an upvalue
	auto firstReturn = cast(ReturnStat)fn.body.body[0];
	assert(firstReturn);
	assert(firstReturn.values.length == 1);

	auto capturedRef = cast(UpvalueExpr)firstReturn.values[0];
	assert(capturedRef);
	assert(capturedRef.id == loop.vars[0]);

	// `return i` in the loop body: the loop variable, seen as a local
	auto loopReturn = cast(ReturnStat)loop.body.body[1];
	assert(loopReturn);
	assert(loopReturn.values.length == 1);

	auto loopVarRef = cast(LocalExpr)loopReturn.values[0];
	assert(loopVarRef);
	assert(loopVarRef.id == loop.vars[0]);

	// `local i` inside the function
	auto innerDecl = cast(DeclarationStat)fn.body.body[1];
	assert(innerDecl);
	assert(innerDecl.keys.length == 1);
	assert(innerDecl.values.length == 0);

	// second `return i` inside the function: the freshly declared local, not the loop variable
	auto lastReturn = cast(ReturnStat)fn.body.body[2];
	assert(lastReturn);
	assert(lastReturn.values.length == 1);

	auto shadowRef = cast(LocalExpr)lastReturn.values[0];
	assert(shadowRef);
	assert(shadowRef.id == innerDecl.keys[0]);
	assert(shadowRef.id != loop.vars[0]);

	assert(diagnostics == []);
}