1 module zua.vm.vm;
2 import zua.vm.engine;
3 import zua.compiler.utils;
4 import std.bitmanip;
5 import std.math;
6 import std.variant;
7 import std.conv;
8 import std.typecons;
9 import std.uuid;
10 import std.algorithm.mutation;
11 
12 version(LDC) {
13 	private pragma(LDC_alloca) void* alloca(size_t);
14 }
15 else {
16 	import core.stdc.stdlib : alloca;
17 }
18 
19 /** A Lua VM engine */
/** A Lua VM engine that interprets the bytecode held in `buffer` */
class VmEngine : Engine {

	/** The bytecode buffer for this engine */
	immutable(ubyte)[] buffer;

	/** Create a new VM engine over `buffer`, identified by `id` */
	this(immutable(ubyte)[] buffer, UUID id) {
		this.buffer = buffer;
		this.id = id;
	}

	/**
	 * Read a little-endian numeric value of type `T` at absolute offset
	 * `index` in `buffer`. Does not move `ip`.
	 */
	pragma(inline) private T read(T)(size_t index) {
		ubyte[T.sizeof] data = buffer[index .. index + T.sizeof];
		return littleEndianToNative!(T, T.sizeof)(data);
	}

	/**
	 * Read a length-prefixed string at absolute offset `index` in `buffer`.
	 * Layout: a little-endian `ulong` byte count followed by the bytes.
	 * The result slices `buffer` (no copy). Does not move `ip`.
	 */
	pragma(inline) private string readStr(size_t index) {
		const length = cast(size_t) read!ulong(index);
		string result = cast(string) buffer[index + 8 .. index + 8 + length];
		return result;
	}

	/** Read a little-endian numeric value of type `T` at `ip` and advance `ip` past it */
	pragma(inline) private T read(T)() {
		ubyte[T.sizeof] data = buffer[ip .. ip + T.sizeof];
		ip = ip + T.sizeof;
		return littleEndianToNative!(T, T.sizeof)(data);
	}

	/** Read a length-prefixed string at `ip` and advance `ip` past it */
	pragma(inline) private string readStr() {
		const length = cast(size_t) read!ulong();
		string result = cast(string) buffer[ip .. ip + length];
		ip = ip + length;
		return result;
	}

	/**
	 * Execute the bytecode of `func`.
	 *
	 * Bytecode header at `func.ip` (little-endian `ulong`s):
	 * +0 string count, +8 nested-function count, +16 local-slot count,
	 * +24 (not read here), +32 code length; code starts at +40.
	 *
	 * Params:
	 *   func = the function to run
	 *   args = argument values, pushed as a tuple by `LdArgs`
	 * Returns: the values popped by the `Ret` instruction
	 * Throws: `LuaError` for invalid arithmetic/concatenation operands and
	 *   non-numeric `for` bounds; errors raised by called functions or
	 *   metamethods propagate unchanged.
	 */
	override Value[] callf(FunctionValue func, Value[] args) {
		// `ip` and `running` are engine-wide state; save both so that a nested
		// callf on this engine (presumably via Value.call) leaves the caller's
		// state intact, including when an error unwinds through here.
		const size_t returnIp = ip;
		auto returnRunning = running;
		scope(exit) {
			ip = returnIp;
			running = returnRunning;
		}

		running = func;

		size_t numLocals = read!ulong(func.ip + 16);
		// Local slots live in this call's stack frame; emplace puts each slot
		// into Value's initial state before the bytecode touches it.
		Value* locals = cast(Value*) alloca(Value.sizeof * numLocals);
		// Fixed-size operand stack; overflow is only caught by D bounds checks.
		const size_t immediateStackSize = 16;
		Value[immediateStackSize] stack;
		// `sp` indexes the top element; it starts at -1 (wrapped) so the first
		// push lands on index 0.
		size_t sp = -1;

		foreach (i; 0 .. numLocals) {
			emplace(&locals[i]);
		}

		ip = func.ip + 40;

		/** Push a value onto the operand stack */
		pragma(inline) void push(Value value) {
			sp++;
			stack[sp] = value;
		}

		/** Pop a value, as-is, off the stack */
		pragma(inline) Value pop() {
			Value res = stack[sp];
			// Clear the vacated slot, presumably so it does not keep a stale
			// heap reference alive.
			stack[sp].heap = null;
			sp--;
			return res;
		}

		/** Pop a value (currently identical to `pop`) */
		pragma(inline) Value popv() {
			return pop();
		}

		/** Pop a value as a tuple: a Tuple yields its elements, anything else a one-element array */
		pragma(inline) Value[] popt() {
			Value res = pop();
			if (res.type == ValueType.Tuple) {
				return res.tuple;
			}
			else {
				return [res];
			}
		}

		/**
		 * Coerce to a number: numbers as-is, strings parsed with `to!double`,
		 * null for unparsable strings and every other type.
		 */
		pragma(inline) Nullable!double coercev(Value v) {
			if (v.type == ValueType.Number) {
				return v.num.nullable;
			}
			else if (v.type == ValueType.String) {
				try {
					return v.str.to!double.nullable;
				}
				catch (ConvException) {
					return Nullable!double();
				}
			}
			else {
				return Nullable!double();
			}
		}

		/** Coerce to a string: numbers via `to!string`, strings as-is, null otherwise */
		pragma(inline) Nullable!string coerceToStr(Value v) {
			if (v.type == ValueType.Number) {
				return v.num.to!string.nullable;
			}
			else if (v.type == ValueType.String) {
				return v.str.nullable;
			}
			else {
				return Nullable!string();
			}
		}

		/** Whether `coercev` would produce a number for `v` */
		pragma(inline) bool isCoerceable(Value v) {
			return !coercev(v).isNull;
		}

		/** Push the first of `results`, or nil when there are none (single-value adjustment) */
		pragma(inline) void pushFirst(Value[] results) {
			push(results.length == 0 ? Value() : results[0]);
		}

		/**
		 * Pop two operands and push the result of an arithmetic operator.
		 *
		 * When both coerce to numbers, `ifNative` (a D expression over the
		 * doubles named `varA`/`varB`) is evaluated. Otherwise the left
		 * operand's `meta` metamethod, then the right one's, is tried.
		 * Throws: `LuaError` if neither operand provides `meta`.
		 */
		pragma(inline) void binaryOp(string ifNative, string meta,
				string varA = "a", string varB = "b")() {
			Value bVal = popv();
			Value aVal = popv();
			const Nullable!double aNum = coercev(aVal);
			const Nullable!double bNum = coercev(bVal);
			if (!aNum.isNull && !bNum.isNull) {
				mixin("const double ", varA, " = aNum.get;");
				mixin("const double ", varB, " = bNum.get;");
				push(Value(mixin(ifNative)));
				return;
			}

			Nullable!(Value[]) attempt = aVal.metacall(meta, [aVal, bVal]);
			if (attempt.isNull) {
				attempt = bVal.metacall(meta, [aVal, bVal]);
			}
			if (attempt.isNull) {
				// Blame the operand that is not number-like.
				string type = aVal.typeStr;
				if (isCoerceable(aVal)) {
					type = bVal.typeStr;
				}
				throw new LuaError(Value("attempt to perform arithmetic on a " ~ type ~ " value"));
			}
			pushFirst(attempt.get);
		}

		/**
		 * Pop two operands and push the result of a string operator.
		 *
		 * Same as `binaryOp`, but operands are coerced to strings and the
		 * error message refers to concatenation.
		 */
		pragma(inline) void binaryOpStr(string ifNative, string meta,
				string varA = "a", string varB = "b")() {
			Value bVal = popv();
			Value aVal = popv();
			const Nullable!string aStr = coerceToStr(aVal);
			const Nullable!string bStr = coerceToStr(bVal);
			if (!aStr.isNull && !bStr.isNull) {
				mixin("const string ", varA, " = aStr.get;");
				mixin("const string ", varB, " = bStr.get;");
				push(Value(mixin(ifNative)));
				return;
			}

			Nullable!(Value[]) attempt = aVal.metacall(meta, [aVal, bVal]);
			if (attempt.isNull) {
				attempt = bVal.metacall(meta, [aVal, bVal]);
			}
			if (attempt.isNull) {
				// Blame the operand that is neither a string nor a number.
				string type = aVal.typeStr;
				if (aVal.type == ValueType.String || aVal.type == ValueType.Number) {
					type = bVal.typeStr;
				}
				throw new LuaError(Value("attempt to concatenate a " ~ type ~ " value"));
			}
			pushFirst(attempt.get);
		}

		/**
		 * Push the arithmetic negation of `v`, falling back to its `__unm`
		 * metamethod. Throws: `LuaError` if neither applies.
		 */
		pragma(inline) void unm(Value v) {
			const Nullable!double vals = coercev(v);
			if (!vals.isNull) {
				push(Value(-vals.get));
				return;
			}

			Nullable!(Value[]) attempt = v.metacall("__unm", [v]);
			if (attempt.isNull) {
				string type = v.typeStr;
				throw new LuaError(Value("attempt to perform arithmetic on a " ~ type ~ " value"));
			}
			pushFirst(attempt.get);
		}

		/**
		 * Fetch string constant `index` of `func`.
		 *
		 * After the code comes a table of `dataLength` ulong offsets, one more
		 * ulong (the segment size, read by LdFun), then the string data the
		 * offsets point into.
		 */
		pragma(inline) string getString(size_t index) {
			ulong localIP = func.ip;
			const ulong dataLength = read!ulong(localIP);
			const ulong codeLength = read!ulong(localIP + 32);
			localIP += 40 + codeLength;
			const ulong datasegPtr = localIP + (dataLength + 1) * 8;
			string val = readStr(datasegPtr + read!ulong(localIP + 8 * index));
			return val;
		}

		/** State of one active numeric `for` loop */
		struct ForState {
			double at;
			double high;
			double step;
			size_t var;
		}

		// Active numeric `for` loops, innermost last.
		ForState[] forStack;

		while (true) {
			const op = cast(Opcode) read!OpcodeSize;
			switch (op) {
			case Opcode.Add:
				binaryOp!("a + b", "__add");
				break;
			case Opcode.Sub:
				binaryOp!("a - b", "__sub");
				break;
			case Opcode.Mul:
				binaryOp!("a * b", "__mul");
				break;
			case Opcode.Div:
				binaryOp!("a / b", "__div");
				break;
			case Opcode.Exp:
				binaryOp!("pow(a, b)", "__pow");
				break;
			case Opcode.Mod:
				// Lua 5.1 definition: result takes the sign of the divisor, and
				// exact multiples give 0 (the previous fmod-based formula
				// returned b for e.g. 6 % -3).
				binaryOp!("a - floor(a / b) * b", "__mod");
				break;
			case Opcode.Unm:
				Value v = popv();
				unm(v);
				break;
			case Opcode.Not:
				Value v = popv();
				push(Value(!v.toBool));
				break;
			case Opcode.Len:
				Value v = popv();
				push(v.length);
				break;
			case Opcode.Concat:
				binaryOpStr!("a ~ b", "__concat");
				break;
			case Opcode.Eq:
				Value b = popv();
				Value a = popv();
				push(Value(a.equals(b)));
				break;
			case Opcode.Ne:
				Value b = popv();
				Value a = popv();
				push(Value(!a.equals(b)));
				break;
			case Opcode.Lt:
				Value b = popv();
				Value a = popv();
				push(Value(a.lessThan(b)));
				break;
			case Opcode.Le:
				Value b = popv();
				Value a = popv();
				push(Value(a.lessOrEqual(b)));
				break;
			case Opcode.Gt:
				// a > b is evaluated as b < a.
				Value b = popv();
				Value a = popv();
				push(Value(b.lessThan(a)));
				break;
			case Opcode.Ge:
				// a >= b is evaluated as b <= a.
				Value b = popv();
				Value a = popv();
				push(Value(b.lessOrEqual(a)));
				break;
			case Opcode.Ret:
				return popt();
			case Opcode.Getfenv:
				push(Value(func.env));
				break;
			case Opcode.Call:
				Value[] callArgs = popt();
				Value base = popv();
				push(Value.makeTuple(base.call(callArgs)));
				break;
			case Opcode.NamecallPrep:
				// Leave the receiver below the looked-up method for Namecall.
				size_t index = cast(size_t) read!CommonOperand;
				Value base = popv();
				push(base);
				push(base.get(Value(getString(index))));
				break;
			case Opcode.Namecall:
				// Call the method with the receiver prepended to the arguments.
				Value[] callArgs = popt();
				Value method = popv();
				Value base = popv();
				push(Value.makeTuple(method.call(base ~ callArgs)));
				break;
			case Opcode.Drop:
				pop();
				break;
			case Opcode.Dup:
				Value v = pop();
				push(v);
				push(v);
				break;
			case Opcode.DupN:
				// Leave count + 1 copies of the top value.
				size_t count = cast(size_t) read!StackOffset;
				Value v = pop();
				foreach (i; 0 .. count + 1) {
					push(v);
				}
				break;
			case Opcode.LdNil:
				push(Value());
				break;
			case Opcode.LdFalse:
				push(Value(false));
				break;
			case Opcode.LdTrue:
				push(Value(true));
				break;
			case Opcode.LdArgs:
				push(Value.makeTuple(args));
				break;
			case Opcode.NewTable:
				push(Value(new TableValue));
				break;
			case Opcode.GetTable:
				Value key = popv();
				Value base = popv();
				push(base.get(key));
				break;
			case Opcode.SetTable:
				Value value = popv();
				Value key = popv();
				Value base = popv();
				base.set(key, value);
				break;
			case Opcode.SetTableRev:
				// Same as SetTable with the operands pushed in the opposite order.
				Value base = popv();
				Value key = popv();
				Value value = popv();
				base.set(key, value);
				break;
			case Opcode.SetArray:
				// Pop `count` values (restored to push order via reverse), then
				// store them at keys 1..n of the table below them.
				const count = cast(size_t) read!StackOffset;
				Value[] tuple;

				foreach (i; 0 .. count) {
					tuple ~= pop();
				}

				// NOTE(review): makeTuple presumably expands a trailing tuple
				// value; confirm against Value.makeTuple.
				tuple = Value.makeTuple(tuple.reverse).tuple;

				Value table = popv();

				foreach (i; 0 .. tuple.length) {
					table.set(Value(i + 1), tuple[i]);
				}

				break;
			case Opcode.DropLoop:
				forStack = forStack[0 .. $ - 1];
				break;
			case Opcode.LdStr:
				const index = cast(size_t) read!CommonOperand;
				push(Value(getString(index)));
				break;
			case Opcode.Jmp:
				// Jump targets are absolute offsets into `buffer`.
				size_t jumpTo = cast(size_t) read!FullWidth;
				ip = jumpTo;
				break;
			case Opcode.JmpT:
				size_t jumpTo = cast(size_t) read!FullWidth;
				const Value v = popv();
				if (v.toBool) {
					ip = jumpTo;
				}
				break;
			case Opcode.JmpF:
				size_t jumpTo = cast(size_t) read!FullWidth;
				const Value v = popv();
				if (!v.toBool) {
					ip = jumpTo;
				}
				break;
			case Opcode.JmpNil:
				size_t jumpTo = cast(size_t) read!FullWidth;
				if (popv().isNil) {
					ip = jumpTo;
				}
				break;
			case Opcode.Pack:
				// Pop `count` values and push them as one tuple in push order.
				const count = cast(size_t) read!StackOffset;
				Value[] tuple;

				foreach (i; 0 .. count) {
					tuple ~= pop();
				}

				push(Value.makeTuple(tuple.reverse));
				break;
			case Opcode.Unpack:
				// Expand the popped tuple into exactly `count` slots, first
				// element on top; missing values are nil.
				const count = cast(size_t) read!StackOffset;
				const Value[] last = popt();

				if (count > last.length) {
					foreach (i; last.length .. count) {
						push(Value());
					}
					foreach_reverse (i; 0 .. last.length) {
						push(last[i]);
					}
				}
				else {
					foreach_reverse (i; 0 .. count) {
						push(last[i]);
					}
				}

				break;
			case Opcode.UnpackD:
				// Like Unpack, then push the values beyond `count` as a tuple on top.
				const count = cast(size_t) read!StackOffset;
				Value[] last = popt();

				if (count > last.length) {
					foreach (i; last.length .. count) {
						push(Value());
					}
					foreach_reverse (i; 0 .. last.length) {
						push(last[i]);
					}
				}
				else {
					foreach_reverse (i; 0 .. count) {
						push(last[i]);
					}
				}

				push(Value.rawTupleUnsafe(count >= last.length ? [] : last[count .. $]));

				break;
			case Opcode.UnpackRev:
				// Like Unpack, but the first element ends deepest in the stack.
				const count = cast(size_t) read!StackOffset;
				const Value[] last = popt();

				if (count > last.length) {
					foreach (i; 0 .. last.length) {
						push(last[i]);
					}
					foreach (i; last.length .. count) {
						push(Value());
					}
				}
				else {
					foreach (i; 0 .. count) {
						push(last[i]);
					}
				}

				break;
			case Opcode.Mkhv:
				// Box local `var` in a fresh heap cell; LdFun captures such cells
				// by pointer as closure upvalues.
				const size_t var = cast(size_t) read!CommonOperand;
				locals[var] = Value(new Value);
				break;
			case Opcode.Get:
				push(locals[cast(size_t) read!CommonOperand]);
				break;
			case Opcode.Set:
				locals[cast(size_t) read!CommonOperand] = popv();
				break;
			case Opcode.GetC:
				push(*func.upvalues[cast(size_t) read!CommonOperand]);
				break;
			case Opcode.SetC:
				*func.upvalues[cast(size_t) read!CommonOperand] = popv();
				break;
			case Opcode.GetRef:
				Value refv = locals[cast(size_t) read!CommonOperand];
				assert(refv.type == ValueType.Heap);
				push(*refv.heap);
				break;
			case Opcode.SetRef:
				Value refv = locals[cast(size_t) read!CommonOperand];
				assert(refv.type == ValueType.Heap);
				*refv.heap = popv();
				break;
			case Opcode.ForPrep:
				// Operands were pushed as initial, limit, step.
				size_t var = cast(size_t) read!CommonOperand;
				const Nullable!double step = coercev(popv());
				const Nullable!double high = coercev(popv());
				const Nullable!double low = coercev(popv());
				if (low.isNull)
					throw new LuaError(Value("'for' initial value must be a number"));
				if (high.isNull)
					throw new LuaError(Value("'for' limit must be a number"));
				if (step.isNull)
					throw new LuaError(Value("'for' step must be a number"));
				ForState state;
				state.var = var;
				state.high = high.get;
				state.step = step.get;
				// Start one step early because Loop advances before testing.
				state.at = low.get - state.step;
				// DropLoop shrinks forStack by slicing; assumeSafeAppend lets
				// this append reuse that storage instead of reallocating.
				forStack.assumeSafeAppend ~= state;
				locals[var] = Value(state.at);
				break;
			case Opcode.Loop:
				// Advance the innermost loop; jump to `jumpTo` once past the
				// limit. A zero step never repeats.
				size_t jumpTo = cast(size_t) read!FullWidth;
				ForState* state = &forStack[$ - 1];
				state.at += state.step;
				locals[state.var] = Value(state.at);
				bool repeat = false;
				if (state.step < 0) {
					repeat = state.at >= state.high;
				}
				else if (state.step > 0) {
					repeat = state.at <= state.high;
				}
				if (!repeat) {
					ip = jumpTo;
				}
				break;
			case Opcode.Introspect:
				// Push a copy of the value `offset` slots below the top.
				size_t offset = cast(size_t) read!StackOffset;
				push(stack[sp - offset]);
				break;
			case Opcode.DropTuple:
				size_t amount = cast(size_t) read!StackOffset;
				foreach (i; 0 .. amount)
					pop();
				break;
			case Opcode.LdNum:
				push(Value(read!double));
				break;
			case Opcode.LdFun:
				// Operands: nested-function index, upvalue count, then one long per
				// upvalue (negative: ~uv indexes this function's upvalues;
				// otherwise a local slot holding a heap cell).
				const ulong index = read!ulong;
				const ulong upvaluesCount = read!ulong;
				Value*[] upvalues;
				foreach (i; 0 .. upvaluesCount) {
					const long uv = read!long;
					if (uv < 0) {
						upvalues ~= func.upvalues[cast(size_t)~uv];
					}
					else {
						Value val = locals[cast(size_t) uv];
						assert(val.type == ValueType.Heap);
						upvalues ~= val.heap;
					}
				}
				// Walk func's layout to the nested-function table: 40-byte
				// header, code, string offset table, string-segment size word,
				// string data, then the function offset table.
				const save = ip;
				ip = func.ip;
				const ulong dataLength = read!ulong;
				const ulong funcLength = read!ulong;
				read!ulong;
				read!ulong;
				const ulong codeLength = read!ulong;
				ip = ip + codeLength;
				ip = ip + 8 * dataLength;
				const ulong datasegSize = read!ulong;
				ip = ip + datasegSize;
				const ulong funcsegIndices = ip;
				// Skip the offset table plus one trailing ulong, mirroring the
				// string segment's layout.
				const ulong funcsegPtr = ip + (funcLength + 1) * 8;
				ip = funcsegIndices + 8 * index;
				ip = funcsegPtr + read!ulong;
				FunctionValue val = new FunctionValue;
				val.env = func.env;
				val.ip = ip;
				val.engine = this;
				val.upvalues = upvalues;
				push(Value(val));
				ip = save;
				break;
			default:
				assert(0, "I don't know how to handle " ~ op.to!string);
			}
		}
	}

	/** Get the FunctionValue for the toplevel function, whose bytecode starts at offset 0, running in `env` */
	FunctionValue getToplevel(TableValue env) {
		auto res = new FunctionValue;
		res.engine = this;
		res.env = env;
		res.ip = 0;
		return res;
	}

}