summaryrefslogtreecommitdiff
path: root/www/wiki/extensions/Scribunto/tests/phpunit/engines/LuaCommon/TestFramework.lua
blob: 6231695ab8c897cc3d8b2b84d698fa22bcc6669b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
-- Module table. The right-hand 'testframework' is evaluated before the new
-- local comes into scope, so it refers to a global of that name: if one is
-- already set it is reused, otherwise a fresh table is created.
local testframework = testframework or {}

-- Ordering used for the non-sequence keys printed by deepToString.
-- A bare table.sort() raises "attempt to compare ..." as soon as the keys mix
-- types (e.g. a number and a string) or have a type that does not support "<"
-- (booleans, tables, functions). Keys are therefore grouped by type name
-- first; numbers and strings keep their natural order, booleans are ordered
-- by their string form, and keys of any other type are left unordered among
-- themselves.
local function compareKeys( x, y )
	local tx, ty = type( x ), type( y )
	if tx ~= ty then
		return tx < ty
	end
	if tx == 'number' or tx == 'string' then
		return x < y
	end
	if tx == 'boolean' then
		return tostring( x ) < tostring( y )
	end
	-- Incomparable: report "not less" both ways so the comparator stays valid
	return false
end

-- Return a string representation of a value, including the deep structure of
-- a table.
--
-- Parameters:
--  val     value to render
--  indent  (optional) number of spaces the enclosing level is indented by,
--          default 0
--  done    (optional) set of tables already rendered; a table found in it is
--          printed as '{ ... }', which also makes reference cycles terminate
--
-- Strings are quoted with "%q", tables are expanded recursively and any other
-- value goes through tostring(). Inside a table the sequence part (as seen by
-- ipairs) is printed first without keys, followed by the remaining keys in
-- sorted order. Keys that are tables are printed as '[{ ... }]'.
local function deepToString( val, indent, done )
	done = done or {}
	indent = indent or 0

	local tp = type( val )
	if tp == 'string' then
		return string.format( "%q", val )
	elseif tp ~= 'table' then
		return tostring( val )
	end

	if done[val] then return '{ ... }' end
	done[val] = true

	local sb = { '{\n' }
	local inner = string.rep( " ", indent + 2 )

	-- Sequence part: values only, in index order
	local donekeys = {}
	for key, value in ipairs( val ) do
		donekeys[key] = true
		sb[#sb + 1] = inner
		sb[#sb + 1] = deepToString( value, indent + 2, done )
		sb[#sb + 1] = ",\n"
	end

	-- Everything ipairs did not reach, in a deterministic order
	local keys = {}
	for key in pairs( val ) do
		if not donekeys[key] then
			keys[#keys + 1] = key
		end
	end
	table.sort( keys, compareKeys )

	for i = 1, #keys do
		local key = keys[i]
		sb[#sb + 1] = inner
		if type( key ) == 'table' then
			sb[#sb + 1] = '[{ ... }] = '
		else
			sb[#sb + 1] = '['
			sb[#sb + 1] = deepToString( key, indent + 3, done )
			sb[#sb + 1] = '] = '
		end
		sb[#sb + 1] = deepToString( val[key], indent + 2, done )
		sb[#sb + 1] = ",\n"
	end

	sb[#sb + 1] = string.rep( " ", indent )
	sb[#sb + 1] = "}"
	return table.concat( sb )
end
-- Also expose the local helper as a field of the module table.
testframework.deepToString = deepToString

-- Deep comparison of two values: tables are compared by content rather than
-- by identity.
-- Returns 4 values (only the first one when the values are equal):
--  boolean  equal?
--  list     key path to first inequality
--  mixed    value from 'a' for key path
--  mixed    value from 'b' for key path
-- 'keypath' and 'done' are optional and carry state through the recursion:
-- the key path walked so far and the table pairs already under comparison.
local function deepEquals( a, b, keypath, done )
	-- Values that are already == need no further inspection
	if a == b then
		return true
	end

	keypath = keypath or {}
	done = done or {}

	-- Values of different types never match
	local kind = type( a )
	if kind ~= type( b ) then
		return false, keypath, a, b
	end

	if kind ~= 'table' then
		-- The only unequal non-table values accepted as matching are two
		-- NaNs. Lua has no standard "isNaN" function, but NaN is the one
		-- number for which "x ~= x" holds.
		if kind == 'number' and a ~= a and b ~= b then
			return true
		end
		return false, keypath, a, b
	end

	-- Tables. A pair that was recorded earlier is either being compared
	-- further up the call stack or has already passed, so answering true
	-- here keeps cyclic structures from recursing forever.
	local seenWithA = done[a]
	if not seenWithA then
		seenWithA = {}
		done[a] = seenWithA
	end
	local seenWithB = done[b]
	if not seenWithB then
		seenWithB = {}
		done[b] = seenWithB
	end
	if seenWithA[b] or seenWithB[a] then
		return true
	end
	seenWithA[b] = true

	-- Slot of the key path that belongs to this nesting level
	local depth = #keypath + 1

	-- Every key of 'a' must hold a deeply equal value in 'b'
	for key in pairs( a ) do
		keypath[depth] = key
		local same, path, fromA, fromB = deepEquals( a[key], b[key], keypath, done )
		if not same then
			return false, path, fromA, fromB
		end
	end
	keypath[depth] = nil

	-- ... and 'b' must not carry any key that 'a' lacks
	for key, fromB in pairs( b ) do
		if a[key] == nil then
			keypath[depth] = key
			return false, keypath, nil, fromB
		end
	end

	return true
end
-- Also expose the local helper as a field of the module table.
testframework.deepEquals = deepEquals

-- Abort the running test and flag it as skipped.
-- Raises an error whose text is 'SKIP: ' followed by the message; level 0
-- keeps Lua from prepending position information to that text.
function testframework.markTestSkipped( message )
	local skipNotice = 'SKIP: ' .. message
	error( skipNotice, 0 )
end

---- Test types available ----
-- Registry mapping a type name (the "type" field of a test) to a table with
-- two functions:
--  format( expect ): the formatter, takes the expected value of the test.
--  exec( func, args ): the executor, takes the function under test and the
--                      table of arguments for it.
-- Both return a string. The test passes if the two strings match (the
-- comparison itself is not performed in this file).
-- A types table that already exists on testframework is kept rather than
-- replaced by an empty one.
testframework.types = testframework.types or {}

-- Execute a function and assert expected results
-- Expected value is a list of return values, or a string error message
--
--  format( expect ): a string expectation becomes 'ERROR: ' .. expect, any
--                    other value is rendered with deepToString.
--  exec( func, args ): calls func with the unpacked args under pcall. On
--                      success the list of return values is rendered with
--                      deepToString. On failure the result is 'ERROR: '
--                      followed by the error text without a leading
--                      "<chunk>:<line>: " prefix. Errors starting with
--                      'SKIP: ' are re-raised unchanged.
testframework.types.Normal = {
	format = function ( expect )
		if type( expect ) == 'string' then
			return 'ERROR: ' .. expect
		else
			return deepToString( expect )
		end
	end,
	exec = function ( func, args )
		local got = { pcall( func, unpack( args ) ) }
		-- The first pcall result is the status flag; strip it so that only
		-- the return values (or the error value) remain.
		if table.remove( got, 1 ) then
			return deepToString( got )
		end

		-- error() accepts any Lua value. Coerce it to a string first so a
		-- table/nil/boolean error is reported as such instead of triggering
		-- a second "bad argument" error in the string functions below.
		local err = got[1]
		if type( err ) ~= 'string' then
			err = tostring( err )
		end

		-- Skip requests are passed on to the caller as they are
		if string.sub( err, 1, 6 ) == 'SKIP: ' then
			error( err, 0 )
		end

		-- Drop the position prefix; the parentheses discard gsub's
		-- second return value (the substitution count).
		return 'ERROR: ' .. ( string.gsub( err, '^%S+:%d+: ', '' ) )
	end
}

-- Test type for functions that return an iterator triplet (as consumed by a
-- generic "for"). Every iteration is rendered as its own section.
-- Expected value is a list of return value lists, one per iteration.
--
--  format( expect ): renders expect[1] .. expect[#expect] with deepToString,
--                    each under an '[iteration N]:' heading, with a blank
--                    line between sections.
--  exec( func, args ): calls func with the unpacked args, then keeps calling
--                      the returned iterator until its first result is nil,
--                      rendering the results of every call the same way.
testframework.types.Iterator = {
	format = function ( expect )
		local sections = {}
		local total = #expect
		for index = 1, total do
			local rendered = deepToString( expect[index] )
			sections[#sections + 1] = '[iteration ' .. index .. ']:\n' .. rendered
		end
		return table.concat( sections, '\n\n' )
	end,
	exec = function ( func, args )
		local iter, state, control = func( unpack( args ) )
		local sections = {}
		local count = 0

		local results = { iter( state, control ) }
		while results[1] ~= nil do
			-- The first result is the control variable for the next call
			control = results[1]
			count = count + 1
			sections[count] = '[iteration ' .. count .. ']:\n' .. deepToString( results )
			results = { iter( state, control ) }
		end

		return table.concat( sections, '\n\n' )
	end
}

-- Execute a function and assert the tostring() form of its results
-- Expected value is a list of return values, or a string error message.
-- Unlike the Normal type, every expected and every actual value is passed
-- through tostring() before the list is rendered.
--
--  format( expect ): a string expectation becomes 'ERROR: ' .. expect;
--                    otherwise each entry of expect is converted with
--                    tostring() and the resulting table is rendered with
--                    deepToString.
--  exec( func, args ): calls func with the unpacked args under pcall. On
--                      success each return value is converted with
--                      tostring() and the list is rendered with
--                      deepToString. On failure the result is 'ERROR: '
--                      followed by the error text without a leading
--                      "<chunk>:<line>: " prefix. Errors starting with
--                      'SKIP: ' are re-raised unchanged.
testframework.types.ToString = {
	format = function ( expect )
		if type( expect ) == 'string' then
			return 'ERROR: ' .. expect
		else
			local ret = {}
			for k, v in pairs( expect ) do
				ret[k] = tostring( v )
			end
			return deepToString( ret )
		end
	end,
	exec = function ( func, args )
		local got = { pcall( func, unpack( args ) ) }
		-- The first pcall result is the status flag; strip it so that only
		-- the return values (or the error value) remain.
		if table.remove( got, 1 ) then
			-- Only existing keys are reassigned, which is permitted while
			-- traversing with pairs.
			for k, v in pairs( got ) do
				got[k] = tostring( v )
			end
			return deepToString( got )
		end

		-- error() accepts any Lua value. Coerce it to a string first so a
		-- table/nil/boolean error is reported as such instead of triggering
		-- a second "bad argument" error in the string functions below.
		local err = got[1]
		if type( err ) ~= 'string' then
			err = tostring( err )
		end

		-- Skip requests are passed on to the caller as they are
		if string.sub( err, 1, 6 ) == 'SKIP: ' then
			error( err, 0 )
		end

		-- Drop the position prefix; the parentheses discard gsub's
		-- second return value (the substitution count).
		return 'ERROR: ' .. ( string.gsub( err, '^%S+:%d+: ', '' ) )
	end
}

-- This takes a list of tests to run, and returns the object used by PHP to
-- call them.
--
-- Each test is a table with the following keys:
--  name: Name of the test
--  expect: Table of results expected
--  func: Function to execute
--  args: (optional) Table of args to be unpacked and passed to the function
--  type: (optional) Formatter/Executor name, default "Normal"
--
-- The returned table has these fields:
--  count: number of tests (#tests)
--  provide( n ): returns n, the name of test n and its expectation as
--                formatted by the test's type
--  run( n ): executes test n through the executor of its type and returns
--            the resulting string, or a "does not exist" message when the
--            list has no entry n
function testframework.getTestProvider( tests )
	return {
		count = #tests,

		provide = function ( n )
			local t = tests[n]
			return n, t.name, testframework.types[t.type or 'Normal'].format( t.expect )
		end,

		run = function ( n )
			local t = tests[n]
			if not t then
				-- Report the requested index. The previous code concatenated
				-- an undefined global 'name' here, which raised "attempt to
				-- concatenate a nil value" instead of returning the message.
				-- tostring() keeps the message usable for any kind of index.
				return 'Test ' .. tostring( n ) .. ' does not exist'
			end
			return testframework.types[t.type or 'Normal'].exec( t.func, t.args or {} )
		end,
	}
end

-- The value of this chunk is the module table populated above.
return testframework