summaryrefslogtreecommitdiff
path: root/www/wiki/extensions/Scribunto/includes/engines/LuaCommon/lualib/ustring/ustring.lua
diff options
context:
space:
mode:
Diffstat (limited to 'www/wiki/extensions/Scribunto/includes/engines/LuaCommon/lualib/ustring/ustring.lua')
-rw-r--r--www/wiki/extensions/Scribunto/includes/engines/LuaCommon/lualib/ustring/ustring.lua1242
1 files changed, 1242 insertions, 0 deletions
diff --git a/www/wiki/extensions/Scribunto/includes/engines/LuaCommon/lualib/ustring/ustring.lua b/www/wiki/extensions/Scribunto/includes/engines/LuaCommon/lualib/ustring/ustring.lua
new file mode 100644
index 00000000..69bfd1d0
--- /dev/null
+++ b/www/wiki/extensions/Scribunto/includes/engines/LuaCommon/lualib/ustring/ustring.lua
@@ -0,0 +1,1242 @@
+local ustring = {}
+
+-- Copy these, just in case
+local S = {
+ byte = string.byte,
+ char = string.char,
+ len = string.len,
+ sub = string.sub,
+ find = string.find,
+ match = string.match,
+ gmatch = string.gmatch,
+ gsub = string.gsub,
+ format = string.format,
+}
+
+---- Configuration ----
+-- To limit the length of strings or patterns processed, set these
+ustring.maxStringLength = math.huge
+ustring.maxPatternLength = math.huge
+
+---- Utility functions ----
+
+local function checkType( name, argidx, arg, expecttype, nilok )
+ if arg == nil and nilok then
+ return
+ end
+ if type( arg ) ~= expecttype then
+ local msg = S.format( "bad argument #%d to '%s' (%s expected, got %s)",
+ argidx, name, expecttype, type( arg )
+ )
+ error( msg, 3 )
+ end
+end
+
+local function checkString( name, s )
+ if type( s ) == 'number' then
+ s = tostring( s )
+ end
+ if type( s ) ~= 'string' then
+ local msg = S.format( "bad argument #1 to '%s' (string expected, got %s)",
+ name, type( s )
+ )
+ error( msg, 3 )
+ end
+ if S.len( s ) > ustring.maxStringLength then
+ local msg = S.format( "bad argument #1 to '%s' (string is longer than %d bytes)",
+ name, ustring.maxStringLength
+ )
+ error( msg, 3 )
+ end
+end
+
+local function checkPattern( name, pattern )
+ if type( pattern ) == 'number' then
+ pattern = tostring( pattern )
+ end
+ if type( pattern ) ~= 'string' then
+ local msg = S.format( "bad argument #2 to '%s' (string expected, got %s)",
+ name, type( pattern )
+ )
+ error( msg, 3 )
+ end
+ if S.len( pattern ) > ustring.maxPatternLength then
+ local msg = S.format( "bad argument #2 to '%s' (pattern is longer than %d bytes)",
+ name, ustring.maxPatternLength
+ )
+ error( msg, 3 )
+ end
+end
+
+-- A private helper that splits a string into codepoints, and also collects the
+-- starting position of each character and the total length in codepoints.
+--
+-- @param s string utf8-encoded string to decode
+-- @return table
+local function utf8_explode( s )
+ local ret = {
+ len = 0,
+ codepoints = {},
+ bytepos = {},
+ }
+
+ local i = 1
+ local l = S.len( s )
+ local cp, b, b2, trail
+ local min
+ while i <= l do
+ b = S.byte( s, i )
+ if b < 0x80 then
+ -- 1-byte code point, 00-7F
+ cp = b
+ trail = 0
+ min = 0
+ elseif b < 0xc2 then
+ -- Either a non-initial code point (invalid here) or
+ -- an overlong encoding for a 1-byte code point
+ return nil
+ elseif b < 0xe0 then
+ -- 2-byte code point, C2-DF
+ trail = 1
+ cp = b - 0xc0
+ min = 0x80
+ elseif b < 0xf0 then
+ -- 3-byte code point, E0-EF
+ trail = 2
+ cp = b - 0xe0
+ min = 0x800
+ elseif b < 0xf4 then
+ -- 4-byte code point, F0-F3
+ trail = 3
+ cp = b - 0xf0
+ min = 0x10000
+ elseif b == 0xf4 then
+ -- 4-byte code point, F4
+ -- Make sure it doesn't decode to over U+10FFFF
+ if S.byte( s, i + 1 ) > 0x8f then
+ return nil
+ end
+ trail = 3
+ cp = 4
+ min = 0x100000
+ else
+ -- Code point over U+10FFFF, or invalid byte
+ return nil
+ end
+
+ -- Check subsequent bytes for multibyte code points
+ for j = i + 1, i + trail do
+ b = S.byte( s, j )
+ if not b or b < 0x80 or b > 0xbf then
+ return nil
+ end
+ cp = cp * 0x40 + b - 0x80
+ end
+ if cp < min then
+ -- Overlong encoding
+ return nil
+ end
+
+ ret.codepoints[#ret.codepoints + 1] = cp
+ ret.bytepos[#ret.bytepos + 1] = i
+ ret.len = ret.len + 1
+ i = i + 1 + trail
+ end
+
+ -- Two past the end (for sub with empty string)
+ ret.bytepos[#ret.bytepos + 1] = l + 1
+ ret.bytepos[#ret.bytepos + 1] = l + 1
+
+ return ret
+end
+
+-- A private helper that finds the character offset for a byte offset.
+--
+-- @param cps table from utf8_explode
+-- @param i int byte offset
+-- @return int
+local function cpoffset( cps, i )
+ local min, max, p = 0, cps.len + 1
+ if i == 0 then
+ return 0
+ end
+ while min + 1 < max do
+ p = math.floor( ( min + max ) / 2 ) + 1
+ if cps.bytepos[p] <= i then
+ min = p - 1
+ end
+ if cps.bytepos[p] >= i then
+ max = p - 1
+ end
+ end
+ return min + 1
+end
+
+---- Trivial functions ----
+-- These functions are the same as the standard string versions
+
+ustring.byte = string.byte
+ustring.format = string.format
+ustring.rep = string.rep
+
+---- Non-trivial functions ----
+-- These functions actually have to be UTF-8 aware
+
+
+-- Determine if a string is valid UTF-8
+--
+-- @param s string
+-- @return boolean
+function ustring.isutf8( s )
+ checkString( 'isutf8', s )
+ return utf8_explode( s ) ~= nil
+end
+
+-- Return the byte offset of a character in a string
+--
+-- @param s string
+-- @param l int codepoint number [default 1]
+-- @param i int starting byte offset [default 1]
+-- @return int|nil
+function ustring.byteoffset( s, l, i )
+ checkString( 'byteoffset', s )
+ checkType( 'byteoffset', 2, l, 'number', true )
+ checkType( 'byteoffset', 3, i, 'number', true )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'byteoffset' (string is not UTF-8)", 2 )
+ end
+
+ i = i or 1
+ if i < 0 then
+ i = S.len( s ) + i + 1
+ end
+ if i < 1 or i > S.len( s ) then
+ return nil
+ end
+ local p = cpoffset( cps, i )
+ if l > 0 and cps.bytepos[p] == i then
+ l = l - 1
+ end
+ if p + l > cps.len then
+ return nil
+ end
+ return cps.bytepos[p + l]
+end
+
+-- Return codepoints from a string
+--
+-- @see string.byte
+-- @param s string
+-- @param i int Starting character [default 1]
+-- @param j int Ending character [default i]
+-- @return int* Zero or more codepoints
+function ustring.codepoint( s, i, j )
+ checkString( 'codepoint', s )
+ checkType( 'codepoint', 2, i, 'number', true )
+ checkType( 'codepoint', 3, j, 'number', true )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'codepoint' (string is not UTF-8)", 2 )
+ end
+ i = i or 1
+ if i < 0 then
+ i = cps.len + i + 1
+ end
+ j = j or i
+ if j < 0 then
+ j = cps.len + j + 1
+ end
+ if j < i then
+ return -- empty result set
+ end
+ i = math.max( 1, math.min( i, cps.len + 1 ) )
+ j = math.max( 1, math.min( j, cps.len + 1 ) )
+ return unpack( cps.codepoints, i, j )
+end
+
+-- Return an iterator over the codepoint (as integers)
+-- for cp in ustring.gcodepoint( s ) do ... end
+--
+-- @param s string
+-- @param i int Starting character [default 1]
+-- @param j int Ending character [default -1]
+-- @return function
+-- @return nil
+-- @return nil
+function ustring.gcodepoint( s, i, j )
+ checkString( 'gcodepoint', s )
+ checkType( 'gcodepoint', 2, i, 'number', true )
+ checkType( 'gcodepoint', 3, j, 'number', true )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'gcodepoint' (string is not UTF-8)", 2 )
+ end
+ i = i or 1
+ if i < 0 then
+ i = cps.len + i + 1
+ end
+ j = j or -1
+ if j < 0 then
+ j = cps.len + j + 1
+ end
+ if j < i then
+ return function ()
+ return nil
+ end
+ end
+ i = math.max( 1, math.min( i, cps.len + 1 ) )
+ j = math.max( 1, math.min( j, cps.len + 1 ) )
+ return function ()
+ if i <= j then
+ local ret = cps.codepoints[i]
+ i = i + 1
+ return ret
+ end
+ return nil
+ end
+end
+
+-- Convert codepoints to a string
+--
+-- @see string.char
+-- @param ... int List of codepoints
+-- @return string
+local function internalChar( t, s, e )
+ local ret = {}
+ for i = s, e do
+ local v = t[i]
+ if type( v ) ~= 'number' then
+ checkType( 'char', i, v, 'number' )
+ end
+ v = math.floor( v )
+ if v < 0 or v > 0x10ffff then
+ error( S.format( "bad argument #%d to 'char' (value out of range)", i ), 2 )
+ elseif v < 0x80 then
+ ret[#ret + 1] = v
+ elseif v < 0x800 then
+ ret[#ret + 1] = 0xc0 + math.floor( v / 0x40 ) % 0x20
+ ret[#ret + 1] = 0x80 + v % 0x40
+ elseif v < 0x10000 then
+ ret[#ret + 1] = 0xe0 + math.floor( v / 0x1000 ) % 0x10
+ ret[#ret + 1] = 0x80 + math.floor( v / 0x40 ) % 0x40
+ ret[#ret + 1] = 0x80 + v % 0x40
+ else
+ ret[#ret + 1] = 0xf0 + math.floor( v / 0x40000 ) % 0x08
+ ret[#ret + 1] = 0x80 + math.floor( v / 0x1000 ) % 0x40
+ ret[#ret + 1] = 0x80 + math.floor( v / 0x40 ) % 0x40
+ ret[#ret + 1] = 0x80 + v % 0x40
+ end
+ end
+ return S.char( unpack( ret ) )
+end
+function ustring.char( ... )
+ return internalChar( { ... }, 1, select( '#', ... ) )
+end
+
+-- Return the length of a string in codepoints, or
+-- nil if the string is not valid UTF-8.
+--
+-- @see string.len
+-- @param string
+-- @return int|nil
+function ustring.len( s )
+ checkString( 'len', s )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ return nil
+ else
+ return cps.len
+ end
+end
+
+-- Private function to return a substring of a string
+--
+-- @param s string
+-- @param cps table Exploded string
+-- @param i int Starting character [default 1]
+-- @param j int Ending character [default -1]
+-- @return string
+local function sub( s, cps, i, j )
+ return S.sub( s, cps.bytepos[i], cps.bytepos[j+1] - 1 )
+end
+
+-- Return a substring of a string
+--
+-- @see string.sub
+-- @param s string
+-- @param i int Starting character [default 1]
+-- @param j int Ending character [default -1]
+-- @return string
+function ustring.sub( s, i, j )
+ checkString( 'sub', s )
+ checkType( 'sub', 2, i, 'number', true )
+ checkType( 'sub', 3, j, 'number', true )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'sub' (string is not UTF-8)", 2 )
+ end
+ i = i or 1
+ if i < 0 then
+ i = cps.len + i + 1
+ end
+ j = j or -1
+ if j < 0 then
+ j = cps.len + j + 1
+ end
+ if j < i then
+ return ''
+ end
+ i = math.max( 1, math.min( i, cps.len + 1 ) )
+ j = math.max( 1, math.min( j, cps.len + 1 ) )
+ return sub( s, cps, i, j )
+end
+
+---- Table-driven functions ----
+-- These functions load a conversion table when called
+
+-- Convert a string to uppercase
+--
+-- @see string.upper
+-- @param s string
+-- @return string
+function ustring.upper( s )
+ checkString( 'upper', s )
+ local map = require 'ustring/upper';
+ local ret = S.gsub( s, '([^\128-\191][\128-\191]*)', map )
+ return ret
+end
+
+-- Convert a string to lowercase
+--
+-- @see string.lower
+-- @param s string
+-- @return string
+function ustring.lower( s )
+ checkString( 'lower', s )
+ local map = require 'ustring/lower';
+ local ret = S.gsub( s, '([^\128-\191][\128-\191]*)', map )
+ return ret
+end
+
+---- Pattern functions ----
+-- Ugh. Just ugh.
+
+-- Cache for character sets (e.g. [a-z])
+local charset_cache = {}
+setmetatable( charset_cache, { __weak = 'kv' } )
+
+-- Private function to find a pattern in a string
+-- Yes, this basically reimplements the whole of Lua's pattern matching, in
+-- Lua.
+--
+-- @see ustring.find
+-- @param s string
+-- @param cps table Exploded string
+-- @param rawpat string Pattern
+-- @param pattern table Exploded pattern
+-- @param init int Starting index
+-- @param noAnchor boolean True to ignore '^'
+-- @return int starting index of the match
+-- @return int ending index of the match
+-- @return string|int* captures
+local function find( s, cps, rawpat, pattern, init, noAnchor )
+ local charsets = require 'ustring/charsets'
+ local anchor = false
+ local ncapt, captures
+ local captparen = {}
+
+ -- Extract the value of a capture from the
+ -- upvalues ncapt and capture.
+ local function getcapt( n, err, errl )
+ if n > ncapt then
+ error( err, errl + 1 )
+ elseif type( captures[n] ) == 'table' then
+ if captures[n][2] == '' then
+ error( err, errl + 1 )
+ end
+ return sub( s, cps, captures[n][1], captures[n][2] ), captures[n][2] - captures[n][1] + 1
+ else
+ return captures[n], math.floor( math.log10( captures[n] ) ) + 1
+ end
+ end
+
+ local match, match_charset, parse_charset
+
+ -- Main matching function. Uses tail recursion where possible.
+ -- Returns the position of the character after the match, and updates the
+ -- upvalues ncapt and captures.
+ match = function ( sp, pp )
+ local c = pattern.codepoints[pp]
+ if c == 0x28 then -- '(': starts capture group
+ ncapt = ncapt + 1
+ captparen[ncapt] = pp
+ local ret
+ if pattern.codepoints[pp + 1] == 0x29 then -- ')': Pattern is '()', capture position
+ captures[ncapt] = sp
+ ret = match( sp, pp + 2 )
+ else
+ -- Start capture group
+ captures[ncapt] = { sp, '' }
+ ret = match( sp, pp + 1 )
+ end
+ if ret then
+ return ret
+ else
+ -- Failed, rollback
+ ncapt = ncapt - 1
+ return nil
+ end
+ elseif c == 0x29 then -- ')': ends capture group, pop current capture index from stack
+ for n = ncapt, 1, -1 do
+ if type( captures[n] ) == 'table' and captures[n][2] == '' then
+ captures[n][2] = sp - 1
+ local ret = match( sp, pp + 1 )
+ if ret then
+ return ret
+ else
+ -- Failed, rollback
+ captures[n][2] = ''
+ return nil
+ end
+ end
+ end
+ error( 'Unmatched close-paren at pattern character ' .. pp, 3 )
+ elseif c == 0x5b then -- '[': starts character set
+ return match_charset( sp, parse_charset( pp ) )
+ elseif c == 0x5d then -- ']'
+ error( 'Unmatched close-bracket at pattern character ' .. pp, 3 )
+ elseif c == 0x25 then -- '%'
+ c = pattern.codepoints[pp + 1]
+ if charsets[c] then -- A character set like '%a'
+ return match_charset( sp, pp + 2, charsets[c] )
+ elseif c == 0x62 then -- '%b': balanced delimiter match
+ local d1 = pattern.codepoints[pp + 2]
+ local d2 = pattern.codepoints[pp + 3]
+ if not d1 or not d2 then
+ error( 'malformed pattern (missing arguments to \'%b\')', 3 )
+ end
+ if cps.codepoints[sp] ~= d1 then
+ return nil
+ end
+ sp = sp + 1
+ local ct = 1
+ while true do
+ c = cps.codepoints[sp]
+ sp = sp + 1
+ if not c then
+ return nil
+ elseif c == d2 then
+ if ct == 1 then
+ return match( sp, pp + 4 )
+ end
+ ct = ct - 1
+ elseif c == d1 then
+ ct = ct + 1
+ end
+ end
+ elseif c == 0x66 then -- '%f': frontier pattern match
+ if pattern.codepoints[pp + 2] ~= 0x5b then
+ error( 'missing \'[\' after %f in pattern at pattern character ' .. pp, 3 )
+ end
+ local pp, charset = parse_charset( pp + 2 )
+ local c1 = cps.codepoints[sp - 1] or 0
+ local c2 = cps.codepoints[sp] or 0
+ if not charset[c1] and charset[c2] then
+ return match( sp, pp )
+ else
+ return nil
+ end
+ elseif c >= 0x30 and c <= 0x39 then -- '%0' to '%9': backreference
+ local m, l = getcapt( c - 0x30, 'invalid capture index %' .. c .. ' at pattern character ' .. pp, 3 )
+ local ep = math.min( cps.len + 1, sp + l )
+ if sub( s, cps, sp, ep - 1 ) == m then
+ return match( ep, pp + 2 )
+ else
+ return nil
+ end
+ elseif not c then -- percent at the end of the pattern
+ error( 'malformed pattern (ends with \'%\')', 3 )
+ else -- something else, treat as a literal
+ return match_charset( sp, pp + 2, { [c] = 1 } )
+ end
+ elseif c == 0x2e then -- '.': match anything
+ if not charset_cache['.'] then
+ local t = {}
+ setmetatable( t, { __index = function ( t, k ) return k end } )
+ charset_cache['.'] = { 1, t }
+ end
+ return match_charset( sp, pp + 1, charset_cache['.'][2] )
+ elseif c == nil then -- end of pattern
+ return sp
+ elseif c == 0x24 and pattern.len == pp then -- '$': assert end of string
+ return ( sp == cps.len + 1 ) and sp or nil
+ else
+ -- Any other character matches itself
+ return match_charset( sp, pp + 1, { [c] = 1 } )
+ end
+ end
+
+ -- Parse a bracketed character set (e.g. [a-z])
+ -- Returns the position after the set and a table holding the matching characters
+ parse_charset = function ( pp )
+ local _, ep
+ local epp = pattern.bytepos[pp] + 1
+ if S.sub( rawpat, epp, epp ) == '^' then
+ epp = epp + 1
+ end
+ if S.sub( rawpat, epp, epp ) == ']' then
+ -- Lua's string module effectively does this
+ epp = epp + 1
+ end
+ repeat
+ _, ep = S.find( rawpat, ']', epp, true )
+ if not ep then
+ error( 'Missing close-bracket for character set beginning at pattern character ' .. pp, 3 )
+ end
+ epp = ep + 1
+ until S.byte( rawpat, ep - 1 ) ~= 0x25 or S.byte( rawpat, ep - 2 ) == 0x25
+ local key = S.sub( rawpat, pattern.bytepos[pp], ep )
+ if charset_cache[key] then
+ local pl, cs = unpack( charset_cache[key] )
+ return pp + pl, cs
+ end
+
+ local p0 = pp
+ local cs = {}
+ local csrefs = { cs }
+ local invert = false
+ pp = pp + 1
+ if pattern.codepoints[pp] == 0x5e then -- '^'
+ invert = true
+ pp = pp + 1
+ end
+ local first = true
+ while true do
+ local c = pattern.codepoints[pp]
+ if not first and c == 0x5d then -- closing ']'
+ pp = pp + 1
+ break
+ elseif c == 0x25 then -- '%'
+ c = pattern.codepoints[pp + 1]
+ if charsets[c] then
+ csrefs[#csrefs + 1] = charsets[c]
+ else
+ cs[c] = 1
+ end
+ pp = pp + 2
+ elseif pattern.codepoints[pp + 1] == 0x2d and pattern.codepoints[pp + 2] and pattern.codepoints[pp + 2] ~= 0x5d then -- '-' followed by another char (not ']'), it's a range
+ for i = c, pattern.codepoints[pp + 2] do
+ cs[i] = 1
+ end
+ pp = pp + 3
+ elseif not c then -- Should never get here, but Just In Case...
+ error( 'Missing close-bracket', 3 )
+ else
+ cs[c] = 1
+ pp = pp + 1
+ end
+ first = false
+ end
+
+ local ret
+ if not csrefs[2] then
+ if not invert then
+ -- If there's only the one charset table, we can use it directly
+ ret = cs
+ else
+ -- Simple invert
+ ret = {}
+ setmetatable( ret, { __index = function ( t, k ) return k and not cs[k] end } )
+ end
+ else
+ -- Ok, we have to iterate over multiple charset tables
+ ret = {}
+ setmetatable( ret, { __index = function ( t, k )
+ if not k then
+ return nil
+ end
+ for i = 1, #csrefs do
+ if csrefs[i][k] then
+ return not invert
+ end
+ end
+ return invert
+ end } )
+ end
+
+ charset_cache[key] = { pp - p0, ret }
+ return pp, ret
+ end
+
+ -- Match a character set table with optional quantifier, followed by
+ -- the rest of the pattern.
+ -- Returns same as 'match' above.
+ match_charset = function ( sp, pp, charset )
+ local q = pattern.codepoints[pp]
+ if q == 0x2a then -- '*', 0 or more matches
+ pp = pp + 1
+ local i = 0
+ while charset[cps.codepoints[sp + i]] do
+ i = i + 1
+ end
+ while i >= 0 do
+ local ret = match( sp + i, pp )
+ if ret then
+ return ret
+ end
+ i = i - 1
+ end
+ return nil
+ elseif q == 0x2b then -- '+', 1 or more matches
+ pp = pp + 1
+ local i = 0
+ while charset[cps.codepoints[sp + i]] do
+ i = i + 1
+ end
+ while i > 0 do
+ local ret = match( sp + i, pp )
+ if ret then
+ return ret
+ end
+ i = i - 1
+ end
+ return nil
+ elseif q == 0x2d then -- '-', 0 or more matches non-greedy
+ pp = pp + 1
+ while true do
+ local ret = match( sp, pp )
+ if ret then
+ return ret
+ end
+ if not charset[cps.codepoints[sp]] then
+ return nil
+ end
+ sp = sp + 1
+ end
+ elseif q == 0x3f then -- '?', 0 or 1 match
+ pp = pp + 1
+ if charset[cps.codepoints[sp]] then
+ local ret = match( sp + 1, pp )
+ if ret then
+ return ret
+ end
+ end
+ return match( sp, pp )
+ else -- no suffix, must match 1
+ if charset[cps.codepoints[sp]] then
+ return match( sp + 1, pp )
+ else
+ return nil
+ end
+ end
+ end
+
+ init = init or 1
+ if init < 0 then
+ init = cps.len + init + 1
+ end
+ init = math.max( 1, math.min( init, cps.len + 1 ) )
+
+ -- Here is the actual match loop. It just calls 'match' on successive
+ -- starting positions (or not, if the pattern is anchored) until it finds a
+ -- match.
+ local sp = init
+ local pp = 1
+ if not noAnchor and pattern.codepoints[1] == 0x5e then -- '^': Pattern is anchored
+ anchor = true
+ pp = 2
+ end
+
+ repeat
+ ncapt, captures = 0, {}
+ local ep = match( sp, pp )
+ if ep then
+ for i = 1, ncapt do
+ captures[i] = getcapt( i, 'Unclosed capture beginning at pattern character ' .. captparen[i], 2 )
+ end
+ return sp, ep - 1, unpack( captures )
+ end
+ sp = sp + 1
+ until anchor or sp > cps.len + 1
+ return nil
+end
+
+-- Private function to decide if a pattern looks simple enough to use
+-- Lua's built-in string library. The following make a pattern not simple:
+-- * If it contains any bytes over 0x7f. We could skip these if they're not
+-- inside brackets and aren't followed by quantifiers and aren't part of a
+-- '%b', but that's too complicated to check.
+-- * If it contains a negated character set.
+-- * If it contains "%a" or any of the other %-prefixed character sets except %z.
+-- * If it contains a '.' not followed by '*', '+', '-'. A bare '.' or '.?'
+-- matches a partial UTF-8 character, but the others will happily enough
+-- match a whole UTF-8 character thinking it's 2, 3 or 4.
+-- * If it contains position-captures.
+-- * If it matches the empty string
+--
+-- @param string pattern
+-- @return boolean
+local function patternIsSimple( pattern )
+ local findWithPcall = function ( ... )
+ local ok, ret = pcall( S.find, ... )
+ return ok and ret
+ end
+
+ return not (
+ S.find( pattern, '[\128-\255]' ) or
+ S.find( pattern, '%[%^' ) or
+ S.find( pattern, '%%[acdlpsuwxACDLPSUWXZ]' ) or
+ S.find( pattern, '%.[^*+-]' ) or S.find( pattern, '%.$' ) or
+ S.find( pattern, '()', 1, true ) or
+ pattern == '' or findWithPcall( '', pattern )
+ )
+end
+
+-- Find a pattern in a string
+--
+-- This works just like string.find, with the following changes:
+-- * Everything works on UTF-8 characters rather than bytes
+-- * Character classes are redefined in terms of Unicode properties:
+-- * %a - Letter
+-- * %c - Control
+-- * %d - Decimal Number
+-- * %l - Lower case letter
+-- * %p - Punctuation
+-- * %s - Separator, plus HT, LF, FF, CR, and VT
+-- * %u - Upper case letter
+-- * %w - Letter or Decimal Number
+-- * %x - [0-9A-Fa-f0-9A-Fa-f]
+--
+-- @see string.find
+-- @param s string
+-- @param pattern string Pattern
+-- @param init int Starting index
+-- @param plain boolean Literal match, no pattern matching
+-- @return int starting index of the match
+-- @return int ending index of the match
+-- @return string|int* captures
+function ustring.find( s, pattern, init, plain )
+ checkString( 'find', s )
+ checkPattern( 'find', pattern )
+ checkType( 'find', 3, init, 'number', true )
+ checkType( 'find', 4, plain, 'boolean', true )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'find' (string is not UTF-8)", 2 )
+ end
+ local pat = utf8_explode( pattern )
+ if pat == nil then
+ error( "bad argument #2 for 'find' (string is not UTF-8)", 2 )
+ end
+
+ if plain or patternIsSimple( pattern ) then
+ if init and init > cps.len + 1 then
+ init = cps.len + 1
+ end
+ local m
+ if plain then
+ m = { true, S.find( s, pattern, cps.bytepos[init], plain ) }
+ else
+ m = { pcall( S.find, s, pattern, cps.bytepos[init], plain ) }
+ end
+ if m[1] then
+ if m[2] then
+ m[2] = cpoffset( cps, m[2] )
+ m[3] = cpoffset( cps, m[3] )
+ end
+ return unpack( m, 2 )
+ end
+ end
+
+ return find( s, cps, pattern, pat, init )
+end
+
+-- Match a string against a pattern
+--
+-- @see ustring.find
+-- @see string.match
+-- @param s string
+-- @param pattern string
+-- @param init int Starting offset for match
+-- @return string|int* captures, or the whole match if there are none
+function ustring.match( s, pattern, init )
+ checkString( 'match', s )
+ checkPattern( 'match', pattern )
+ checkType( 'match', 3, init, 'number', true )
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'match' (string is not UTF-8)", 2 )
+ end
+ local pat = utf8_explode( pattern )
+ if pat == nil then
+ error( "bad argument #2 for 'match' (string is not UTF-8)", 2 )
+ end
+
+ if patternIsSimple( pattern ) then
+ local ret = { pcall( S.match, s, pattern, cps.bytepos[init] ) }
+ if ret[1] then
+ return unpack( ret, 2 )
+ end
+ end
+
+ local m = { find( s, cps, pattern, pat, init ) }
+ if not m[1] then
+ return nil
+ end
+ if m[3] then
+ return unpack( m, 3 )
+ end
+ return sub( s, cps, m[1], m[2] )
+end
+
+-- Return an iterator function over the matches for a pattern
+--
+-- @see ustring.find
+-- @see string.gmatch
+-- @param s string
+-- @param pattern string
+-- @return function
+-- @return nil
+-- @return nil
+function ustring.gmatch( s, pattern )
+ checkString( 'gmatch', s )
+ checkPattern( 'gmatch', pattern )
+ if patternIsSimple( pattern ) then
+ local ret = { pcall( S.gmatch, s, pattern ) }
+ if ret[1] then
+ return unpack( ret, 2 )
+ end
+ end
+
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'gmatch' (string is not UTF-8)", 2 )
+ end
+ local pat = utf8_explode( pattern )
+ if pat == nil then
+ error( "bad argument #2 for 'gmatch' (string is not UTF-8)", 2 )
+ end
+ local init = 1
+
+ return function ()
+ local m = { find( s, cps, pattern, pat, init, true ) }
+ if not m[1] then
+ return nil
+ end
+ init = m[2] + 1
+ if m[3] then
+ return unpack( m, 3 )
+ end
+ return sub( s, cps, m[1], m[2] )
+ end
+end
+
+-- Replace pattern matches in a string
+--
+-- @see ustring.find
+-- @see string.gsub
+-- @param s string
+-- @param pattern string
+-- @param repl string|function|table
+-- @param int n
+-- @return string
+-- @return int
+function ustring.gsub( s, pattern, repl, n )
+ checkString( 'gsub', s )
+ checkPattern( 'gsub', pattern )
+ checkType( 'gsub', 4, n, 'number', true )
+ if patternIsSimple( pattern ) then
+ local ret = { pcall( S.gsub, s, pattern, repl, n ) }
+ if ret[1] then
+ return unpack( ret, 2 )
+ end
+ end
+
+ if n == nil then
+ n = 1e100
+ end
+ if n < 1 then
+ -- No replacement
+ return s, 0
+ end
+
+ local cps = utf8_explode( s )
+ if cps == nil then
+ error( "bad argument #1 for 'gsub' (string is not UTF-8)", 2 )
+ end
+ local pat = utf8_explode( pattern )
+ if pat == nil then
+ error( "bad argument #2 for 'gsub' (string is not UTF-8)", 2 )
+ end
+
+ if pat.codepoints[1] == 0x5e then -- '^': Pattern is anchored
+ -- There can be only the one match, so make that explicit
+ n = 1
+ end
+
+ local tp
+ if type( repl ) == 'function' then
+ tp = 1
+ elseif type( repl ) == 'table' then
+ tp = 2
+ elseif type( repl ) == 'string' then
+ tp = 3
+ elseif type( repl ) == 'number' then
+ repl = tostring( repl )
+ tp = 3
+ else
+ checkType( 'gsub', 3, repl, 'function or table or string' )
+ end
+
+ local init = 1
+ local ct = 0
+ local ret = {}
+ local zeroAdjustment = 0
+ repeat
+ local m = { find( s, cps, pattern, pat, init + zeroAdjustment ) }
+ if not m[1] then
+ break
+ end
+ if init < m[1] then
+ ret[#ret + 1] = sub( s, cps, init, m[1] - 1 )
+ end
+ local mm = sub( s, cps, m[1], m[2] )
+ local val
+ if tp == 1 then
+ if m[3] then
+ val = repl( unpack( m, 3 ) )
+ else
+ val = repl( mm )
+ end
+ elseif tp == 2 then
+ val = repl[m[3] or mm]
+ elseif tp == 3 then
+ if ct == 0 and #m < 11 then
+ local ss = S.gsub( repl, '%%[%%0-' .. ( #m - 2 ) .. ']', 'x' )
+ ss = S.match( ss, '%%[0-9]' )
+ if ss then
+ error( 'invalid capture index ' .. ss .. ' in replacement string', 2 )
+ end
+ end
+ local t = {
+ ["%0"] = mm,
+ ["%1"] = m[3],
+ ["%2"] = m[4],
+ ["%3"] = m[5],
+ ["%4"] = m[6],
+ ["%5"] = m[7],
+ ["%6"] = m[8],
+ ["%7"] = m[9],
+ ["%8"] = m[10],
+ ["%9"] = m[11],
+ ["%%"] = "%"
+ }
+ val = S.gsub( repl, '%%[%%0-9]', t )
+ end
+ ret[#ret + 1] = val or mm
+ init = m[2] + 1
+ ct = ct + 1
+ zeroAdjustment = m[2] < m[1] and 1 or 0
+ until init > cps.len or ct >= n
+ if init <= cps.len then
+ ret[#ret + 1] = sub( s, cps, init, cps.len )
+ end
+ return table.concat( ret ), ct
+end
+
+---- Unicode Normalization ----
+-- These functions load a conversion table when called
+
+local function internalDecompose( cps, decomp )
+ local cp = {}
+ local normal = require 'ustring/normalization-data'
+
+ -- Decompose into cp, using the lookup table and logic for hangul
+ for i = 1, cps.len do
+ local c = cps.codepoints[i]
+ local m = decomp[c]
+ if m then
+ for j = 0, #m do
+ cp[#cp + 1] = m[j]
+ end
+ else
+ cp[#cp + 1] = c
+ end
+ end
+
+ -- Now sort combiners by class
+ local i, l = 1, #cp
+ while i < l do
+ local cc1 = normal.combclass[cp[i]]
+ local cc2 = normal.combclass[cp[i+1]]
+ if cc1 and cc2 and cc1 > cc2 then
+ cp[i], cp[i+1] = cp[i+1], cp[i]
+ if i > 1 then
+ i = i - 1
+ else
+ i = i + 1
+ end
+ else
+ i = i + 1
+ end
+ end
+
+ return cp, 1, l
+end
+
+local function internalCompose( cp, _, l )
+ local normal = require 'ustring/normalization-data'
+
+ -- Since NFD->NFC can never expand a character sequence, we can do this
+ -- in-place.
+ local comp = normal.comp[cp[1]]
+ local sc = 1
+ local j = 1
+ local lastclass = 0
+ for i = 2, l do
+ local c = cp[i]
+ local ccc = normal.combclass[c]
+ if ccc then
+ -- Trying a combiner with the starter
+ if comp and lastclass < ccc and comp[c] then
+ -- Yes!
+ c = comp[c]
+ cp[sc] = c
+ comp = normal.comp[c]
+ else
+ -- No, copy it to the right place for output
+ j = j + 1
+ cp[j] = c
+ lastclass = ccc
+ end
+ elseif comp and lastclass == 0 and comp[c] then
+ -- Combining two adjacent starters
+ c = comp[c]
+ cp[sc] = c
+ comp = normal.comp[c]
+ else
+ -- New starter, doesn't combine
+ j = j + 1
+ cp[j] = c
+ comp = normal.comp[c]
+ sc = j
+ lastclass = 0
+ end
+ end
+
+ return cp, 1, j
+end
+
+-- Normalize a string to NFC
+--
+-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
+-- UTF-8.
+--
+-- @param s string
+-- @return string|nil
+function ustring.toNFC( s )
+ checkString( 'toNFC', s )
+
+ -- ASCII is always NFC
+ if not S.find( s, '[\128-\255]' ) then
+ return s
+ end
+
+ local cps = utf8_explode( s )
+ if cps == nil then
+ return nil
+ end
+ local normal = require 'ustring/normalization-data'
+
+ -- First, scan through to see if the string is definitely already NFC
+ local ok = true
+ for i = 1, cps.len do
+ local c = cps.codepoints[i]
+ if normal.check[c] then
+ ok = false
+ break
+ end
+ end
+ if ok then
+ return s
+ end
+
+ -- Next, expand to NFD then recompose
+ return internalChar( internalCompose( internalDecompose( cps, normal.decomp ) ) )
+end
+
+-- Normalize a string to NFD
+--
+-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
+-- UTF-8.
+--
+-- @param s string
+-- @return string|nil
+function ustring.toNFD( s )
+ checkString( 'toNFD', s )
+
+ -- ASCII is always NFD
+ if not S.find( s, '[\128-\255]' ) then
+ return s
+ end
+
+ local cps = utf8_explode( s )
+ if cps == nil then
+ return nil
+ end
+
+ local normal = require 'ustring/normalization-data'
+ return internalChar( internalDecompose( cps, normal.decomp ) )
+end
+
+-- Normalize a string to NFKC
+--
+-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
+-- UTF-8.
+--
+-- @param s string
+-- @return string|nil
+function ustring.toNFKC( s )
+ checkString( 'toNFKC', s )
+
+ -- ASCII is always NFKC
+ if not S.find( s, '[\128-\255]' ) then
+ return s
+ end
+
+ local cps = utf8_explode( s )
+ if cps == nil then
+ return nil
+ end
+ local normal = require 'ustring/normalization-data'
+
+ -- Next, expand to NFKD then recompose
+ return internalChar( internalCompose( internalDecompose( cps, normal.decompK ) ) )
+end
+
+-- Normalize a string to NFKD
+--
+-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
+-- UTF-8.
+--
+-- @param s string
+-- @return string|nil
+function ustring.toNFKD( s )
+ checkString( 'toNFKD', s )
+
+ -- ASCII is always NFKD
+ if not S.find( s, '[\128-\255]' ) then
+ return s
+ end
+
+ local cps = utf8_explode( s )
+ if cps == nil then
+ return nil
+ end
+
+ local normal = require 'ustring/normalization-data'
+ return internalChar( internalDecompose( cps, normal.decompK ) )
+end
+
+return ustring