1 module zua.interop.classwrapper;
2 import zua.interop.functions;
3 import zua.interop.userdata;
4 import zua.interop.table;
5 import zua.interop;
6 import zua.vm.engine;
7 import std.typecons;
8 import std.traits;
9 import std.range;
10 import std.meta;
11 import std.uuid;
12 
/**
 * Combines the overload set `overloads` of `member` into a single
 * Lua-callable `DConsumable`.
 *
 * On every call the overloads are tried in three passes of decreasing
 * strictness, handed to `convertParameters`:
 *   - pass 0: exact match
 *   - pass 1: exact length
 *   - pass 2: any fit
 * Within a pass they are tried in order. The first overload whose arguments
 * convert is invoked, and its result is converted back with `convertReturn`.
 * A `void` overload yields no values.
 *
 * Error handling:
 *   - A `ConversionException` only means "try the next candidate".
 *   - Any other `Exception` is rethrown as a `LuaError`; an existing
 *     `LuaError` is rethrown unchanged.
 *   - If nothing fits, the first overload's conversion is repeated with the
 *     default strictness purely so it raises its error, again as a
 *     `LuaError`. If that conversion unexpectedly succeeds, `assert(0)` is hit.
 *
 * NOTE(review): a `ConversionException` escaping from *inside* a called
 * overload is also treated as a non-match and silently skipped.
 *
 * Params:
 *   isStatic = when false, the produced function takes a leading `Userdata`
 *              (the Lua receiver), which is ignored. The overloads are
 *              presumably already bound to their object.
 *   member   = member name, forwarded to `convertParameters`.
 *   overloads = the callables to dispatch between.
 */
private DConsumable makeFunctionFromOverloads(bool isStatic, string member, T...)(T overloads) {
	// Turns a D exception into the error type surfaced to Lua.
	static Throwable toLuaError(Exception e) {
		if (auto luaError = cast(LuaError)e) {
			return luaError;
		}
		return new LuaError(Value(e.msg));
	}

	pragma(inline) DConsumable[] dispatch(DConsumable[] consumableArgs...) {
		// strictness 0: exact match, 1: exact length, 2: any fit
		static foreach (strictness; 0 .. 3) {
			static foreach (idx; 0 .. T.length) {{
				alias Signature = Unqual!(T[idx]);
				auto candidate = overloads[idx];
				try {
					auto converted = convertParameters!(Signature, member, strictness)(consumableArgs);
					static if (is(ReturnType!Signature == void)) {
						candidate(converted.expand);
						return [];
					}
					else {
						return candidate(converted.expand).convertReturn!Signature;
					}
				}
				catch (ConversionException) {
					// Arguments do not fit this overload at this strictness; keep looking.
				}
				catch (Exception e) {
					throw toLuaError(e);
				}
			}}
		}

		// Nothing matched: redo the first overload's conversion purely so that
		// it raises a descriptive error for the caller.
		try {
			convertParameters!(Unqual!(T[0]), member)(consumableArgs);
		}
		catch (Exception e) {
			throw toLuaError(e);
		}
		assert(0);
	}

	static if (isStatic) {
		auto entry = delegate DConsumable[](DConsumable[] args...) {
			return dispatch(args);
		};
	}
	else {
		// The receiver userdata is accepted and discarded.
		auto entry = delegate DConsumable[](Userdata _, DConsumable[] args...) {
			return dispatch(args);
		};
	}
	return DConsumable(entry);
}
72 
/// Type of the converter returned by `makeClassWrapper`: wraps an existing
/// `T` instance into Lua userdata.
template ClassConverter(T) {
	alias ClassConverter = Userdata delegate(T instance);
}

/// Cache of built wrappers, keyed by the `typeid` of the wrapped class.
/// Converters are stored with their parameter type erased to `Object`.
/// Module-level variables are thread-local in D, so each thread has its own
/// cache.
private Tuple!(DConsumable, ClassConverter!Object)[TypeInfo] classWrapperMemo;
76 
/**
 * Create a class wrapper for use in Lua.
 *
 * The wrapper for each class is built once and cached by `typeid(T)`.
 * Subsequent calls return the cached pair.
 *
 * Returns: a tuple of
 *   - the static class object, which exposes `new` and the public static
 *     members, and
 *   - a converter that wraps an existing `T` instance as Lua userdata.
 */
Tuple!(DConsumable, ClassConverter!T) makeClassWrapper(T)() if (is(T == class)) {
	TypeInfo info = typeid(T);
	// Single associative-array lookup: `in` yields a pointer to the cached
	// entry, or null.
	if (auto cached = info in classWrapperMemo) {
		// The cache stores the converter with its parameter erased to Object;
		// cast it back to the T-typed delegate.
		return tuple((*cached)[0], cast(ClassConverter!T)(*cached)[1]);
	}
	auto fresh = makeClassWrapperUnmemoized!T;
	classWrapperMemo[info] = tuple(fresh[0], cast(ClassConverter!Object)fresh[1]);
	return fresh;
}
90 
/// Names of every field of class `T`, listed for `T` itself first and then for
/// each ancestor in `TransitiveBaseTypeTuple` order. Walks the same hierarchy
/// as `AllFields`, so a name's position indexes its type there.
private template AllFieldNamesTuple(alias T) {
	alias Hierarchy = AliasSeq!(T, TransitiveBaseTypeTuple!T);
	enum AllFieldNamesTuple = staticMap!(FieldNameTuple, Hierarchy);
}
95 
/// Types of every field of class `T`, listed for `T` itself first and then for
/// each ancestor in `TransitiveBaseTypeTuple` order (parallel to
/// `AllFieldNamesTuple`).
private template AllFields(alias T) {
	alias AllFields = staticMap!(Fields, T, TransitiveBaseTypeTuple!T);
}
100 
/// True when symbol `T` is accessible from any module, i.e. its protection is
/// `public` or `export`. `export` is public visibility plus shared-library
/// export. `private`, `protected` and `package` symbols are treated as hidden
/// and are not exposed to Lua.
private template IsVisible(alias T) {
	enum bool IsVisible = __traits(getProtection, T) == "public"
		|| __traits(getProtection, T) == "export";
}
109 
// Owner tags stamped onto userdata created by this module:
//   - classWrapperId marks wrapped class instances;
//   - classWrapperStaticId marks static class objects.
// Module-level variables are thread-local in D and `static this` runs once per
// thread, so every thread gets its own random pair.
// NOTE(review): userdata created on one thread will therefore not be
// recognised by isClassWrapper on another thread — confirm that is intended.
private UUID classWrapperId, classWrapperStaticId;

static this() {
	classWrapperId = randomUUID();
	classWrapperStaticId = randomUUID();
}
117 
/**
 * Returns true if the given userdata is a class wrapper, i.e. its owner tag
 * is the one this module stamps on wrapped class instances.
 */
bool isClassWrapper(Userdata data) {
	immutable bool ownedByInstanceWrapper = data.owner == classWrapperId;
	return ownedByInstanceWrapper;
}
122 
/**
 * Returns true if the given userdata is a static class wrapper, i.e. its
 * owner tag is the one this module stamps on static class objects.
 */
bool isStaticClassWrapper(Userdata data) {
	immutable bool ownedByStaticWrapper = data.owner == classWrapperStaticId;
	return ownedByStaticWrapper;
}
127 
/**
 * Builds the Lua-facing wrapper for class `T`. This version is uncached;
 * `makeClassWrapper` memoizes it.
 *
 * Two metatables are created:
 *  - `staticMeta` is attached to a sentinel userdata that stands for the
 *    class itself. Indexing it yields `new` (the constructor) and the public
 *    static members of `T`.
 *  - `instanceMeta` is attached to every userdata wrapping a `T` instance.
 *    Indexing it yields the public non-static fields and methods of `T`,
 *    including inherited ones.
 *
 * The following members are never exposed:
 *  - names starting with `_`;
 *  - non-public members (see `IsVisible`);
 *  - the names excluded by `NotSpecial`.
 *
 * Returns: the static class userdata (as a `DConsumable`) and a converter
 * that wraps an existing `T` instance into instance userdata.
 */
private Tuple!(DConsumable, ClassConverter!T) makeClassWrapperUnmemoized(T)() if (is(T == class)) {
	// Sentinel userdata representing the class itself. The handlers below
	// never read its data pointer.
	Userdata staticClass = Userdata.create(cast(void*)-1);
	staticClass.owner = classWrapperStaticId;

	Table staticMeta = Table.create();
	Table instanceMeta = Table.create();

	// Implements `T.new(...)`, reached through the "new" key of staticIndex.
	// A constructor overload is chosen in three passes of decreasing
	// strictness. The new object is then wrapped in instance userdata.
	Userdata constructor(DConsumable[] consumableArgs...) {
		T res;
		static if (!__traits(hasMember, res, "__ctor")) {
			// No user-declared constructor: any arguments are ignored.
			res = new T;
			auto ures = Userdata.create(cast(void*)res, instanceMeta.Nullable!Table);
			ures.owner = classWrapperId;
			return ures;
		}
		else {
			alias Overloads = __traits(getOverloads, res, "__ctor");
			alias GetPointer(alias U) = typeof(&U);
			alias GetDelegate(alias U) = ReturnType!U delegate(Parameters!U);
			alias GetDelegateFromPointer(alias U) = GetDelegate!(GetPointer!U);
			alias OverloadsArray = staticMap!(GetDelegateFromPointer, Overloads);
			// first we check for an exact match, then we try an approximate match
			static foreach (iterations; 0..3) {
				static foreach (i; 0..OverloadsArray.length) {{
					alias U = Unqual!(OverloadsArray[i]);
					try {
						// if this is the first time: exact match
						// second time: exact length
						// third time: any fit
						auto args = convertParameters!(U, "new", iterations)(consumableArgs);
						res = new T(args.expand);
						auto ures = Userdata.create(cast(void*)res, instanceMeta.Nullable!Table);
						ures.owner = classWrapperId;
						return ures;
					}
					catch (ConversionException e) {
						// Do nothing
					}
					catch (Exception e) {
						if (cast(LuaError)e) {
							throw e;
						}
						else {
							throw new LuaError(Value(e.msg));
						}
					}
				}}
			}
			// if we didn't find a match, we throw an error:
			try {
				convertParameters!(Unqual!(OverloadsArray[0]), "new")(consumableArgs);
			}
			catch (Exception e) {
				if (cast(LuaError)e) {
					throw e;
				}
				else {
					throw new LuaError(Value(e.msg));
				}
			}
			assert(0);
		}
	}

	// Names never exposed to Lua: members inherited from Object and D
	// operator-overload hooks.
	enum bool NotSpecial(string name) =
		name != "toString" && name != "toHash" && name != "Monitor" && name != "factory"
		&& name != "opUnary" && name != "opIndexUnary" && name != "opSlice" && name != "opCast"
		&& name != "opBinary" && name != "opBinaryRight" && name != "opEquals" && name != "opCmp"
		&& name != "opCall" && name != "opAssign" && name != "opIndexAssign" && name != "opOpAssign"
		&& name != "opIndexOpAssign" && name != "opIndex" && name != "opDollar" && name != "opDispatch";
	alias Members = Filter!(NotSpecial, __traits(allMembers, T));

	// __index for instances. Resolves `key` to one of:
	//   - the value of a public field;
	//   - a method bound to the instance, returned as a callable;
	//   - the result of calling a public @property getter.
	// Throws: Exception when no exposable instance member is named `key`.
	DConsumable instanceIndex(Userdata lself, string key) {
		T self = cast(T)lself.data; // @suppress(dscanner.suspicious.unused_variable)

		static foreach (member; Members) {{
			static if (!hasStaticMember!(T, member) && member[0] != '_') {
				if (member == key) {
					enum size_t index = staticIndexOf!(member, AllFieldNamesTuple!T);
					static if (index != -1) {{
						alias FieldType = AllFields!T[index];
						static if (isConvertible!FieldType && IsVisible!(__traits(getMember, self, member))) {
							return DConsumable(__traits(getMember, self, member));
						}
					}}
					else {
						alias Overloads = Filter!(IsVisible, __traits(getOverloads, self, member));
						static if (Overloads.length > 0) {
							alias GetPointer(alias U) = typeof(&U);
							alias GetDelegate(alias U) = ReturnType!U delegate(Parameters!U);
							alias GetDelegateFromPointer(alias U) = GetDelegate!(GetPointer!U);
							alias OverloadsArray = staticMap!(GetDelegateFromPointer, Overloads);
							Tuple!OverloadsArray overloads;
							// Maps a visible overload, matched by source location, to its
							// position in the unfiltered overload set of `self`.
							template FindOverload(string file, int line, int col) {
								alias ContextedOverloads = __traits(getOverloads, self, member);
								static foreach (i; 0..ContextedOverloads.length) {
									static if (AliasSeq!(__traits(getLocation, ContextedOverloads[i])) == AliasSeq!(file, line, col)) {
										enum FindOverload = i;
									}
								}
							}
							static foreach (i; 0..OverloadsArray.length) {
								overloads[i] = &__traits(getOverloads, self, member)[FindOverload!(__traits(getLocation, Overloads[i]))];
							}
							DConsumable func = makeFunctionFromOverloads!(false, member, OverloadsArray)(overloads.expand);
							static if (hasFunctionAttributes!(__traits(getMember, self, member), "@property")) {
								// property getter: invoke it now and return its value
								return (cast(DConsumableFunction)func)([DConsumable(lself)])[0];
							}
							else {
								return func;
							}
						}
					}
				}
			}
		}}

		throw new Exception("attempt to index member '" ~ key ~ "'");
	}

	// __newindex for instances. Assigns a public field, or invokes a public
	// @property setter with `value`.
	// Throws: Exception for anything else.
	void instanceNewIndex(Userdata lself, string key, DConsumable value) {
		T self = cast(T)lself.data; // @suppress(dscanner.suspicious.unused_variable)

		static foreach (member; Members) {{
			static if (!hasStaticMember!(T, member) && member[0] != '_') {
				if (member == key) {
					enum size_t index = staticIndexOf!(member, AllFieldNamesTuple!T);
					static if (index != -1) {{
						alias FieldType = AllFields!T[index];
						static if (isConvertible!FieldType && IsVisible!(__traits(getMember, self, member))) {
							__traits(getMember, self, member) = value.opCast!(FieldType, 3);
							return;
						}
					}}
					else {
						// Only query function attributes on actual functions. Members such
						// as manifest constants, nested types or function templates have
						// no plain overloads, and hasFunctionAttributes rejects
						// non-callables at compile time.
						static if (__traits(getOverloads, self, member).length > 0) {
							enum bool isPropertyMember = hasFunctionAttributes!(__traits(getMember, self, member), "@property");
						}
						else {
							enum bool isPropertyMember = false;
						}
						static if (isPropertyMember) {
							alias Overloads = Filter!(IsVisible, __traits(getOverloads, self, member));
							static if (Overloads.length > 0) {
								alias GetPointer(alias U) = typeof(&U);
								alias GetDelegate(alias U) = ReturnType!U delegate(Parameters!U);
								alias GetDelegateFromPointer(alias U) = GetDelegate!(GetPointer!U);
								alias OverloadsArray = staticMap!(GetDelegateFromPointer, Overloads);
								Tuple!OverloadsArray overloads;
								// Maps a visible overload, matched by source location, to its
								// position in the unfiltered overload set of `self`.
								template FindOverload(string file, int line, int col) {
									alias ContextedOverloads = __traits(getOverloads, self, member);
									static foreach (i; 0..ContextedOverloads.length) {
										static if (AliasSeq!(__traits(getLocation, ContextedOverloads[i])) == AliasSeq!(file, line, col)) {
											enum FindOverload = i;
										}
									}
								}
								static foreach (i; 0..OverloadsArray.length) {
									overloads[i] = &__traits(getOverloads, self, member)[FindOverload!(__traits(getLocation, Overloads[i]))];
								}
								DConsumable func = makeFunctionFromOverloads!(false, member, OverloadsArray)(overloads.expand);
								(cast(DConsumableFunction)func)([DConsumable(lself), value]);
								return;
							}
						}
						else {
							throw new Exception("attempt to modify member '" ~ key ~ "'");
						}
					}
				}
			}
		}}

		throw new Exception("attempt to modify member '" ~ key ~ "'");
	}

	instanceMeta["__index"] = &instanceIndex;
	instanceMeta["__newindex"] = &instanceNewIndex;
	instanceMeta["__tostring"] = delegate(Userdata lself) {
		T self = cast(T)lself.data;
		return self.toString;
	};
	// Under standard Lua semantics, a __metatable field makes getmetatable
	// return this value instead of the table.
	instanceMeta["__metatable"] = "The metatable is locked";

	// __index for the class object. `new` yields the constructor. Otherwise
	// `key` resolves to one of:
	//   - the value of a public static field;
	//   - a public static function, returned as a callable;
	//   - the result of calling a static @property getter.
	// Throws: Exception when no exposable static member is named `key`.
	DConsumable staticIndex(Userdata, string key) {
		if (key == "new") {
			DConsumable res;
			res.__ctor!(typeof(&constructor), "new")(&constructor);
			return res;
		}

		static foreach (member; Members) {{
			static if (hasStaticMember!(T, member) && member[0] != '_') {
				if (member == key) {
					static if (!__traits(isStaticFunction, __traits(getMember, T, member))) {{
						alias FieldType = typeof(__traits(getMember, T, member));
						static if (isConvertible!FieldType && IsVisible!(__traits(getMember, T, member))) {
							return DConsumable(__traits(getMember, T, member));
						}
					}}
					else {
						alias Overloads = Filter!(IsVisible, __traits(getOverloads, T, member));
						static if (Overloads.length > 0) {
							alias GetPointer(alias U) = typeof(&U);
							alias OverloadsArray = staticMap!(GetPointer, Overloads);
							Tuple!OverloadsArray overloads;
							static foreach (i; 0..OverloadsArray.length) {
								overloads[i] = &Overloads[i];
							}
							DConsumable func = makeFunctionFromOverloads!(true, member, OverloadsArray)(overloads.expand);
							static if (hasFunctionAttributes!(__traits(getMember, T, member), "@property")) {
								// property getter: invoke it now and return its value
								return (cast(DConsumableFunction)func)([])[0];
							}
							else {
								return func;
							}
						}
					}
				}
			}
		}}

		throw new Exception("attempt to index member '" ~ key ~ "'");
	}

	// __newindex for the class object. Assigns a public static field, or
	// invokes a public static @property setter with `value`.
	// Throws: Exception for anything else.
	void staticNewIndex(Userdata, string key, DConsumable value) {
		static foreach (member; Members) {{
			static if (hasStaticMember!(T, member) && member[0] != '_') {
				if (member == key) {
					static if (!__traits(isStaticFunction, __traits(getMember, T, member))) {{
						alias FieldType = typeof(__traits(getMember, T, member));
						static if (isConvertible!FieldType && IsVisible!(__traits(getMember, T, member))) {
							__traits(getMember, T, member) = value.opCast!(FieldType, 3);
							return;
						}
					}}
					else {
						static if (hasFunctionAttributes!(__traits(getMember, T, member), "@property")) {
							alias Overloads = Filter!(IsVisible, __traits(getOverloads, T, member));
							// Guard needed for a static @property whose overloads are all
							// non-public. makeFunctionFromOverloads indexes the first
							// overload, so an empty set would not compile. Falling through
							// instead reaches the final "attempt to modify" error.
							static if (Overloads.length > 0) {
								alias GetPointer(alias U) = typeof(&U);
								alias OverloadsArray = staticMap!(GetPointer, Overloads);
								Tuple!OverloadsArray overloads;
								static foreach (i; 0..OverloadsArray.length) {
									overloads[i] = &Overloads[i];
								}
								DConsumable func = makeFunctionFromOverloads!(true, member, OverloadsArray)(overloads.expand);
								(cast(DConsumableFunction)func)([value]);
								return;
							}
						}
						else {
							throw new Exception("attempt to modify member '" ~ key ~ "'");
						}
					}
				}
			}
		}}

		throw new Exception("attempt to modify member '" ~ key ~ "'");
	}

	staticMeta["__index"] = &staticIndex;
	staticMeta["__newindex"] = &staticNewIndex;
	staticMeta["__tostring"] = delegate() {
		return fullyQualifiedName!T;
	};
	staticMeta["__metatable"] = "The metatable is locked";

	staticClass.metatable = staticMeta;

	// The converter wraps an existing instance exactly like the constructor
	// wraps a new one: same metatable, same owner tag.
	return tuple(DConsumable(staticClass), delegate Userdata(T instance) {
		auto ures = Userdata.create(cast(void*)instance, instanceMeta.Nullable!Table);
		ures.owner = classWrapperId;
		return ures;
	});
}
397 
// Fixture classes exercised by the unittest below. Many details here are
// load-bearing for the Lua assertions in that test, so edit them together
// with the script:
//   - member visibility;
//   - @property attributes;
//   - overload declaration order;
//   - exact return values.
version(unittest) {
	// C: base fixture. Covers instance and static fields, plain methods,
	// instance and static @property getters/setters, overload sets and toString.
	class C {

		// static field: read and written from Lua as `C.y`
		static int y = 5;
		// instance field: read/written as `ins.x` and observed through getX
		int x;

		int rand() {
			return 4; // chosen randomly by a dice roll
		}

		// const @property getter: read from Lua without a call (`ins.rand2`)
		int rand2() const @property {
			return 4; // see above
		}

		// @property setter: `ins.xMangler = 16` stores 16 * 8 in x
		void xMangler(int x) @property {
			this.x = x * 8;
		}

		// static @property getter: read as `C.rand3`
		static int rand3() @property {
			return 6; // we decided to add another dice roll into the mix
		}

		// static @property setter: `C.yMangler = 8` stores 8 * 9 in y
		static void yMangler(int y) @property {
			C.y = y * 9;
		}

		// Overload set used for dispatch tests. foo(int) is declared first.
		// When no overload fits, the test expects a "number expected" error,
		// presumably produced from the first overload's conversion. Keep the
		// order unless the test is updated too.
		int foo(int x) {
			return x * 3;
		}

		string foo(string x) {
			return x ~ " is fun";
		}

		int foo(int x, int y) {
			return x + y * 100;
		}

		// static overloads: `C.goo()` and `C.goo(4)`
		static int goo() {
			return 7;
		}

		static int goo(int s) {
			return s * 2;
		}

		// Backs the __tostring metamethod. `toString` itself is filtered out
		// and must not be indexable from Lua.
		override string toString() const {
			return "C is a class";
		}

	}

	// D: derived fixture. Covers inheritance, overriding and visibility
	// filtering of fields, methods, overloads and statics.
	class D : C {

		// private field: indexing `ins3.z` from Lua must fail
		private int z = 10;

		// override: calls from Lua must reach this version (returns 5)
		override int rand() {
			return 5; // turns out the last one wasn't so random
		}

		// public overloads reachable as `ins3:fooey()` and `ins3:fooey(3)`
		int fooey() {
			return 71;
		}

		int fooey(int s) {
			return s * 2;
		}

		// Private overload that must stay hidden. The test expects
		// `ins3:fooey("3")` to resolve to fooey(int) and return 6.
		private int fooey(string s) {
			return cast(int)s.length;
		}

		// private @property overload: must never be reached from Lua (it throws)
		private void privateProp(int) @property {
			throw new Exception("how could you fail these tests");
		}

		// only publicStatic may be reachable through `D.<name>`
		private static void privateStatic() {}
		protected static void protectedStatic() {}
		static void publicStatic() {}

		// Public @property setter. The test expects `ins3.privateProp = 32`
		// to land here with "32", leaving z == 24.
		void privateProp(string a) @property {
			z = cast(int)a.length * 12;
		}

		// @property getter used to observe z from Lua
		int getZ() @property {
			return z;
		}

		// reads the inherited field x
		int go() {
			return x * 3;
		}

	}

	// E: fixture with only argument-taking constructors, so `E.new()` must
	// fail and `E.new(...)` must choose among these overloads.
	class E {

		int x;

		this(int x) {
			this.x = x;
		}

		// `E.new("12").x` is expected to be 2
		this(string y) {
			this.x = cast(int)y.length;
		}

		// `E.new(12, 3).x` is expected to be 6 ("12".length * 3)
		this(string y, int z) {
			this.x = cast(int)y.length * z;
		}

	}
}
510 
// End-to-end test of the class wrapper. Exposes C, D, E and a D function
// taking a C to a fresh Lua environment, then runs a Lua script that asserts
// on the wrapper's behaviour. A LuaError from the script is printed to stderr
// before the test fails.
unittest {
	import zua;
	import std.stdio;

	Common c = new Common(GlobalOptions.FullAccess);

	// expose the classes (as static class objects) and one pre-built C instance
	c.env.expose!("C", C);
	c.env.expose!("D", D);
	c.env.expose!("E", E);
	c.env["ins2"] = new C;

	// D function taking a C. The script checks that wrapped C and D instances
	// convert back, and that an E instance is rejected.
	int getX(C self) {
		return self.x;
	}

	c.env.expose!"getX"(&getX);

	try {
		c.run("file.lua", q"{
			assert(tostring(C) == "zua.interop.classwrapper.C")
			assert(ins2.rand2 == 4)
			local ins3 = D.new()
			assert(not pcall(function() return ins3.z end))
			assert(ins3:fooey(3) == 6)
			assert(ins3:fooey("3") == 6)
			ins3.privateProp = 32
			assert(ins3.getZ == 24)
			assert(not pcall(function() return D.privateStatic end))
			assert(not pcall(function() return D.protectedStatic end))
			assert(pcall(D.publicStatic))
			assert(tostring(ins3) == "C is a class")
			assert(not pcall(function() return ins3.toString end))
			assert(ins3:fooey() == 71)
			assert(ins3:rand() == 5)
			assert(ins3.x == 0)
			assert(getX(ins3) == 0)
			ins3.x = 7
			assert(ins3:go() == 21)
			assert(getX(ins3) == 7)
			local ins = C.new()
			assert(tostring(ins) == "C is a class")
			assert(ins.x == 0)
			assert(getX(ins) == 0)
			assert(ins:foo(2) == 6)
			assert(ins:foo(2.9) == 6)
			assert(ins:foo("programming") == "programming is fun")
			assert(ins:foo(7, "8") == 807)
			ins.x = 23
			assert(getX(ins) == 23)
			assert(ins.x == 23)
			assert(select(2, pcall(ins.foo, ins)) == "bad argument #1 to 'foo' (number expected, got nil)")
			assert(C.y == 5)
			C.y = 18
			assert(C.y == 18)
			assert(C.goo() == 7)
			assert(C.goo(4) == 8)
			assert(not pcall(function() return C.foo end))
			assert(ins:rand() == 4)
			assert(ins.rand2 == 4)
			ins.xMangler = 16
			assert(ins.x == 128)
			assert(C.rand3 == 6)
			C.yMangler = 8
			assert(C.y == 72)
			assert(not pcall(E.new))
			assert(E.new(12).x == 12)
			assert(E.new("12").x == 2)
			assert(E.new(12, 3).x == 6)
			local e = E.new(3)
			assert(not pcall(getX, e))
		}");
	}
	catch (LuaError e) {
		// surface the Lua-side error before failing the test
		stderr.writeln("Error: " ~ e.data.toString);
		assert(0);
	}
}