-- symbols.lua
require 'pl'
utils.import 'pl.func'
local ops = require 'pl.operator'
local List = require 'pl.List'
local append,concat = table.insert,table.concat
local compare,find_if,compare_no_order,imap,reduce,count_map = tablex.compare,tablex.find_if,tablex.compare_no_order,tablex.imap,tablex.reduce,tablex.count_map
local unpack = table.unpack
--- Attach a concrete value to a placeholder.
-- The value is written with rawset under the key 'value', so no
-- __newindex metamethod of `self` is involved.
-- @param self  table receiving the binding
-- @param val   value to store
function bindval (self,val)
    local key = 'value'
    rawset(self,key,val)
end
-- Table indexed by operator symbol (e.g. '+'); fold looks up the function
-- implementing an operator here when it evaluates constant sub-expressions.
local optable = ops.optable
--- Render an expression as an S-expression string.
-- Non-expressions are converted with tostring, placeholders ('X' nodes)
-- are shown by their `repr`, and any other node becomes
-- "(op arg1 arg2 ...)" with each argument rendered recursively.
-- @param e  expression or plain value
-- @return string form of `e`
function sexpr (e)
    if not isPE(e) then
        return tostring(e)
    end
    if e.op == 'X' then
        return e.repr
    end
    local rendered = tablex.imap(sexpr,e)
    local body = table.concat(rendered,' ')
    return '('..e.op..' '..body..')'
end
-- psexpr(e): convenience that prints the S-expression rendering of e
-- (print composed with sexpr).
psexpr = compose(print,sexpr)
--- Structural equality of two expressions.
-- A plain value never equals an expression; two plain values are compared
-- with ==. Two expressions must share the same operator; placeholders are
-- equal when their `repr` matches; '+' and '*' nodes compare their
-- arguments ignoring order, every other operator compares them in order.
-- @param e1  first expression or value
-- @param e2  second expression or value
-- @return whether the two are considered equal
function equals (e1,e2)
    local lhs_is_pe,rhs_is_pe = isPE(e1),isPE(e2)
    if lhs_is_pe ~= rhs_is_pe then
        return false
    end
    if not (lhs_is_pe and rhs_is_pe) then
        return e1 == e2
    end
    if e1.op ~= e2.op then
        return false
    end
    if e1.op == 'X' then
        return e1.repr == e2.repr
    end
    if e1.op == '+' or e1.op == '*' then
        -- commutative operators: argument order is irrelevant
        return compare_no_order(e1,e2,equals)
    end
    return compare(e1,e2,equals)
end
--- Gather the leaves of a chain of identical operators.
-- Walks `e` depth-first; every sub-node whose operator is `op` is descended
-- into, everything else is appended to `ls` via ls:append.
-- @param op  operator symbol being flattened
-- @param e   expression or plain value to walk
-- @param ls  collector object providing an `append` method
function tcollect (op,e,ls)
    local same_op = isPE(e) and e.op == op
    if not same_op then
        ls:append(e)
        return
    end
    -- plain index loop over the node's arguments
    for idx = 1,#e do
        tcollect(op,e[idx],ls)
    end
end
--- Flatten nested applications of e's own operator into one list.
-- @param e  expression node; its `op` decides which nested nodes are merged
-- @return a new List holding the collected arguments
function rcollect (e)
    local collected = List()
    local op = e.op
    tcollect(op,e,collected)
    return collected
end
--- Normalise the shape of an expression tree, in place.
-- For '+' and '*' nodes, nested nodes with the same operator are merged so
-- the node holds all operands directly. For any other operator each
-- argument is balanced recursively. Placeholders and plain values are
-- returned untouched.
-- @param e  expression or plain value
-- @return `e` itself (modified in place when it is a non-placeholder node)
function balance (e)
    if not isPE(e) or e.op == 'X' then
        return e
    end
    local op = e.op
    local flat
    if op == '+' or op == '*' then
        flat = rcollect(e)
    else
        flat = imap(balance,e)
    end
    -- write the new argument list back over the node's own slots
    for idx = 1,#flat do
        e[idx] = flat[idx]
    end
    return e
end
--- Simplify an expression by folding constants and applying identities.
-- Behaviour by case:
--   * plain value: returned as is;
--   * placeholder ('X'): its raw 'value' field if set, else the placeholder;
--   * '+' / '*': constant arguments are reduced to one constant, the
--     identities x*0 = 0, x*1 = x, x+0 = x are applied, and repeated equal
--     symbolic arguments are merged (x+x -> x*2, x*x -> x^2);
--   * other operators with only constant arguments: evaluated through
--     optable, or '?' when optable has no entry for the operator;
--   * '^' with exponent 1 or 0: the base, or 1;
--   * anything else: a node rebuilt from the folded arguments.
-- @param e  expression or plain value
-- @return the simplified expression or value
function fold (e)
    if not isPE(e) then
        return e
    end
    local op = e.op
    if op == 'X' then
        -- rawget: read the binding directly, without any metamethod
        local bound = rawget(e,'value')
        if bound then return bound end
        return e
    end

    -- fold every argument first
    local args = imap(fold,e)

    if op == '*' or op == '+' then
        -- split into symbolic (expression) and constant arguments
        local split = List.partition(args,isPE)
        local symbolic,const = split[true],split[false]
        if const then
            -- collapse all constants into a single value
            if #const == 1 then
                const = const[1]
            else
                const = const:reduce(optable[op])
            end
            -- nothing symbolic left: the whole node is that constant
            if not symbolic then return const end
            if op == '*' and const == 0 then return 0 end
            local identity = (op == '*') and 1 or 0
            if const == identity then
                if #symbolic == 1 then return symbolic[1] end
                -- drop the identity element from the argument list
                const = nil
            end
        end
        -- rebuild the argument list, merging repeated symbolic terms
        local res = {}
        if const then res[#res + 1] = const end
        for term,times in pairs(count_map(symbolic,equals)) do
            if times > 1 then
                if op == '*' then
                    term = term ^ times
                else
                    term = term * times
                end
            end
            res[#res + 1] = term
        end
        if #res == 1 then return res[1] end
        return PE{op=op,unpack(res)}
    end

    if not find_if(args,isPE) then
        -- all arguments are constants: evaluate the operator directly
        local opfn = optable[op]
        if opfn then
            return opfn(unpack(args))
        end
        return '?'
    end

    if op == '^' then
        if args[2] == 1 then return args[1] end
        if args[2] == 0 then return 1 end
    end
    return PE{op=op,unpack(args)}
end
--- Distribute a product over a sum: a*(t1 + t2 + ... + tn).
-- Only applies when `e` is a '*' node whose second argument is a '+' node;
-- the first argument is multiplied into every term of that sum and each
-- resulting product is expanded recursively. Anything else is returned
-- unchanged.
-- All terms of the sum are distributed, so sums holding more than two
-- operands (as produced by balance) no longer lose their trailing terms.
-- For a two-term sum the result is exactly t1*a + t2*a.
-- @param e  expression or plain value
-- @return the expanded expression, or `e` itself when nothing applies
function expand (e)
    if not (isPE(e) and e.op == '*' and isPE(e[2]) and e[2].op == '+') then
        return e
    end
    local a,b = e[1],e[2]
    local sum = expand(b[1]*a)
    -- index loop bounded by #b: never reads past the last term
    for i = 2,#b do
        sum = sum + expand(b[i]*a)
    end
    return sum
end
--- Test whether a value is a Lua number.
-- @param x  any value
-- @return true when type(x) is 'number', false otherwise
function isnumber (x)
    local kind = type(x)
    return kind == 'number'
end
--- Does expression `e` mention the variable `x`?
-- Plain values never do. A placeholder does when its `repr` equals
-- x.repr. For any other node the result of find_if over the node's
-- entries is returned directly, so a hit is reported as whatever
-- find_if yields rather than as a strict boolean.
-- @param e  expression or plain value to search
-- @param x  placeholder being looked for (compared by `repr`)
-- @return a true value when `x` occurs in `e`, a false value otherwise
function references (e,x)
    if not isPE(e) then
        return false
    end
    if e.op == 'X' then
        return x.repr == e.repr
    end
    return find_if(e,references,x)
end
--- Build a '*' node from a list of operands.
-- @param args  array of operands
-- @return the expression created by PE
local function muli (args)
    local node = {op='*',unpack(args)}
    return PE(node)
end
--- Build a '+' node from a list of operands.
-- @param args  array of operands
-- @return the expression created by PE
local function addi (args)
    local node = {op='+',unpack(args)}
    return PE(node)
end
--- Symbolic derivative of `e` with respect to the variable `x`.
-- Rules implemented:
--   * anything that does not reference `x`  -> 0
--   * the variable itself                   -> 1
--   * '+'  linearity: the sum of the derivatives of every term
--   * '*'  product rule over all factors; factors whose derivative folds
--          to 0 contribute nothing
--   * '^'  with a numeric exponent b and base a:  b * a^(b-1) * a'
--          (the factor a' is omitted when it folds to 1, so x^n still
--          yields n*x^(n-1))
-- Any other operator is not handled and the function returns nothing.
-- @param e  expression or plain value
-- @param x  placeholder to differentiate against
-- @return the derivative expression/number, or nil when unsupported
function diff (e,x)
    if not (isPE(e) and references(e,x)) then
        return 0
    end
    local op = e.op
    if op == 'X' then
        return 1
    elseif op == '+' then
        local terms = imap(diff,e,x)
        return balance(addi(terms))
    elseif op == '*' then
        local res = {}
        for i = 1,#e do
            local d = fold(diff(e[i],x))
            if d ~= 0 then
                -- copy the factors and swap factor i for its derivative
                local factors = {unpack(e)}
                factors[i] = d
                append(res,balance(muli(factors)))
            end
        end
        if #res > 1 then
            return addi(res)
        end
        return res[1]
    elseif op == '^' then
        local a,b = e[1],e[2]
        if isnumber(b) then
            -- chain rule: the base need not be the bare variable
            local da = fold(diff(a,x))
            if da == nil then return nil end   -- base uses an unsupported operator
            if da == 0 then return 0 end
            local outer = b*a^(b-1)
            if da == 1 then return outer end
            return outer*da
        end
    end
end