/*
 * Decompiled with CFR 0.152.
 */
package org.matheclipse.core.reflection.system;

import org.matheclipse.core.eval.EvalEngine;
import org.matheclipse.core.eval.exception.Validate;
import org.matheclipse.core.eval.interfaces.AbstractFunctionEvaluator;
import org.matheclipse.core.expression.F;
import org.matheclipse.core.expression.IConstantHeaders;
import org.matheclipse.core.generic.Functors;
import org.matheclipse.core.interfaces.IAST;
import org.matheclipse.core.interfaces.IExpr;
import org.matheclipse.core.interfaces.IInteger;
import org.matheclipse.core.interfaces.INumber;
import org.matheclipse.core.reflection.system.Apart;
import org.matheclipse.core.reflection.system.Multinomial;
import org.matheclipse.generic.combinatoric.KPermutationsIterable;

/**
 * Implements the {@code Expand[expr]} function: distributes products over sums and expands
 * integer powers of sums using multinomial coefficients.
 */
public class Expand
extends AbstractFunctionEvaluator
implements IConstantHeaders {

    /**
     * Returns {@code expr} viewed as a sum.
     *
     * @param expr any expression
     * @return {@code expr} itself if it already is a {@code Plus} AST, otherwise a new
     *         one-argument {@code Plus[expr]}
     */
    private static IAST assurePlus(IExpr expr) {
        if (expr.isPlus()) {
            return (IAST) expr;
        }
        IAST astPlus = F.Plus();
        astPlus.add(expr);
        return astPlus;
    }

    /**
     * Expands a single AST one level.
     * <ul>
     * <li>{@code Power[Plus[...], k]} with an integer {@code k} is expanded with
     * {@link #expandPower}. For negative {@code k}, the reciprocal of the expanded
     * {@code |k|}-th power is returned.</li>
     * <li>{@code Times[...]} is split into numerator and denominator by
     * {@code Apart.getFractionalPartsTimes}. Each part that is itself a product is multiplied
     * out.</li>
     * <li>{@code Plus[...]} with at least two summands has {@code Expand} mapped over its
     * summands.</li>
     * </ul>
     *
     * @param ast the expression to expand
     * @return the expanded expression, or {@code null} if no rule applies
     */
    public static IExpr expand(IAST ast) {
        if (ast.isPower()) {
            IExpr base = ast.get(1);
            if (base instanceof IAST && ast.get(2) instanceof IInteger
                    && ((IAST) base).head() == F.Plus) {
                int exp = Validate.checkIntType(ast, 2, Integer.MIN_VALUE);
                if (exp < 0) {
                    if (exp == Integer.MIN_VALUE) {
                        // -Integer.MIN_VALUE overflows back to Integer.MIN_VALUE, so the
                        // exponent cannot be negated. Leave the power unevaluated.
                        return null;
                    }
                    return F.Power(expandPower((IAST) base, -exp), F.CN1);
                }
                return expandPower((IAST) base, exp);
            }
            return null;
        }
        if (ast.isTimes()) {
            // The product is rebuilt below as Times[parts[0], parts[1]^-1], so parts[0] is the
            // numerator and parts[1] is the denominator.
            IExpr[] parts = Apart.getFractionalPartsTimes(ast, false);
            IExpr numerator = parts[0];
            IExpr denominator = parts[1];
            if (numerator.equals(F.C1)) {
                if (denominator.isTimes()) {
                    return F.Power(expandTimes((IAST) denominator), F.CN1);
                }
                return null;
            }
            if (denominator.equals(F.C1)) {
                return expandTimes(ast);
            }
            if (numerator.isTimes()) {
                numerator = expandTimes((IAST) numerator);
            }
            if (denominator.isTimes()) {
                denominator = expandTimes((IAST) denominator);
            }
            return F.Times(numerator, (IExpr) F.Power(denominator, F.CN1));
        }
        if (ast.isASTSizeGE(F.Plus, 3)) {
            // Map Expand over the summands, using Expand[Null] as the template whose first
            // argument is replaced by each summand.
            return ast.map(Functors.replace1st(F.Expand(F.Null)));
        }
        return null;
    }

    /**
     * Expands {@code plusAST^n} into a sum of monomials with multinomial coefficients.
     *
     * @param plusAST a {@code Plus} AST whose arguments are the summands of the base
     * @param n the exponent. It is expected to be non-negative; for a negative value the
     *        partition enumeration produces no terms and an empty {@code Plus} is returned.
     * @return {@code plusAST} itself for {@code n == 1}, the integer {@code 1} for
     *         {@code n == 0}, otherwise a {@code Plus} of the (unevaluated) expansion terms
     */
    public static IExpr expandPower(IAST plusAST, int n) {
        if (n == 1) {
            return plusAST;
        }
        if (n == 0) {
            // x^0 == 1 for the sum being expanded; 0 would be mathematically wrong here.
            return F.C1;
        }
        IAST expandedResult = F.Plus();
        NumberPartition part = new NumberPartition(plusAST, n, expandedResult);
        part.partition();
        return expandedResult;
    }

    /**
     * Multiplies out the factors of a product from left to right, using
     * {@link #expandTimesBinary} for each step.
     *
     * @param timesAST a {@code Times} AST; it is assumed to have at least one argument
     * @return the evaluated, expanded product
     */
    private static IExpr expandTimes(IAST timesAST) {
        IExpr result = timesAST.get(1);
        for (int i = 2; i < timesAST.size(); i++) {
            result = expandTimesBinary(result, timesAST.get(i));
        }
        return result;
    }

    /**
     * Distributes the product {@code expr0 * expr1} over the summands of both factors and
     * evaluates the resulting sum. Non-sum factors are treated as one-term sums.
     *
     * @param expr0 left factor
     * @param expr1 right factor
     * @return the evaluated sum of all pairwise products
     */
    public static IExpr expandTimesBinary(IExpr expr0, IExpr expr1) {
        if (expr0.isNumber() && expr1.isPlus()) {
            return EvalEngine.eval(expandTimesPlus((INumber) expr0, (IAST) expr1));
        }
        IAST ast0 = assurePlus(expr0);
        IAST ast1 = assurePlus(expr1);
        return EvalEngine.eval(expandTimesPlus(ast0, ast1));
    }

    /**
     * Builds the unevaluated sum of the products {@code Times[a_i, b_j]} for every argument
     * {@code a_i} of {@code expr0} and {@code b_j} of {@code expr1}.
     *
     * @param expr0 first sum
     * @param expr1 second sum
     * @return a new {@code Plus} AST holding the pairwise products, ordered by {@code i} and
     *         then by {@code j}
     */
    public static IAST expandTimesPlus(IAST expr0, IAST expr1) {
        IAST pList = F.Plus();
        for (int i = 1; i < expr0.size(); i++) {
            // Times[a_i, Null] is the template; each argument of expr1 replaces the Null and the
            // resulting product is appended to pList.
            expr1.args().map(pList, Functors.replace2nd(F.Times(expr0.get(i), (IExpr) F.Null)));
        }
        return pList;
    }

    /**
     * Distributes a numeric factor over the arguments of a sum.
     *
     * @param expr1 the numeric factor
     * @param ast the sum
     * @return a new, unevaluated {@code Plus} of {@code Times[expr1, a_i]} for each argument
     *         {@code a_i} of {@code ast}
     */
    public static IAST expandTimesPlus(INumber expr1, IAST ast) {
        IAST pList = F.Plus();
        for (int i = 1; i < ast.size(); i++) {
            pList.add(F.Times((IExpr) expr1, (IExpr) ast.get(i)));
        }
        return pList;
    }

    /**
     * Evaluates {@code Expand[arg]}.
     *
     * @param ast the {@code Expand[...]} call
     * @return {@code null} if the call does not have exactly one argument. Otherwise returns
     *         the expansion of the argument, or the argument itself when it is not an AST or
     *         no expansion rule applies.
     */
    @Override
    public IExpr evaluate(IAST ast) {
        if (ast.size() != 2) {
            return null;
        }
        IExpr arg1 = ast.get(1);
        if (arg1 instanceof IAST) {
            IExpr expanded = expand((IAST) arg1);
            if (expanded != null) {
                return expanded;
            }
        }
        return arg1;
    }

    /**
     * Generates the terms of {@code (x_1 + ... + x_m)^n}.
     *
     * <p>It enumerates the partitions of {@code n} into at most {@code m} positive parts. For
     * every arrangement of each partition over the {@code m} summands, one term is appended to
     * {@code expandedResult}.
     */
    private static class NumberPartition {
        /** Sum that receives the generated terms. */
        private final IAST expandedResult;
        /** Number of summands in the base. */
        private final int m;
        /** Exponent of the power being expanded. */
        private final int n;
        /** Current partition of {@code n}; slots past the partition's length hold 0. */
        private final int[] parts;
        /** {@code Power[x_k, Null]} templates, read back with index {@code k + 1}. */
        private final IAST precalculatedPowerASTs;

        NumberPartition(IAST plusAST, int n, IAST expandedResult) {
            this.expandedResult = expandedResult;
            this.n = n;
            this.m = plusAST.size() - 1;
            this.parts = new int[this.m];
            this.precalculatedPowerASTs = F.List();
            for (IExpr expr : plusAST) {
                this.precalculatedPowerASTs.add(F.Power(expr, F.Null));
            }
        }

        /**
         * Appends one term {@code Times[coefficient, x_k^e_k, ...]} per exponent arrangement
         * produced by {@code KPermutationsIterable} for the exponents in {@code j}. These are
         * presumably the distinct orderings of the exponent multiset. Summands whose exponent
         * is 0 are omitted from the term.
         *
         * @param j the current partition (non-increasing, zero-padded to length {@code m})
         */
        private void addFactor(int[] j) {
            KPermutationsIterable perm = new KPermutationsIterable(j, this.m, this.m);
            // The coefficient depends only on the multiset of exponents, so compute it once.
            IInteger multinomial = F.integer(Multinomial.multinomial(j, this.n));
            IAST times = F.Times();
            for (int[] indices : perm) {
                IAST timesAST = times.clone();
                timesAST.add(multinomial);
                for (int k = 0; k < this.m; k++) {
                    if (indices[k] != 0) {
                        IAST power = this.precalculatedPowerASTs.getAST(k + 1).clone();
                        power.set(2, F.integer(indices[k]));
                        timesAST.add(power);
                    }
                }
                this.expandedResult.add(timesAST);
            }
        }

        /** Appends the complete expansion to {@code expandedResult}. */
        public void partition() {
            partition(this.n, this.n, 0);
        }

        /**
         * Recursively fills {@code parts[currentIndex..]} with non-increasing positive integers,
         * each at most {@code max}, that sum to {@code remaining}. It calls {@link #addFactor}
         * for each completed partition.
         *
         * @param remaining amount still to be distributed
         * @param max largest part allowed in this slot (keeps the sequence non-increasing)
         * @param currentIndex slot of {@code parts} being filled
         */
        private void partition(int remaining, int max, int currentIndex) {
            if (remaining == 0) {
                addFactor(this.parts);
                return;
            }
            if (currentIndex >= this.m) {
                // Would need more than m parts: no term for this branch.
                return;
            }
            int old = this.parts[currentIndex];
            for (int i = Math.min(max, remaining); i >= 1; i--) {
                this.parts[currentIndex] = i;
                partition(remaining - i, i, currentIndex + 1);
            }
            // Restore the slot so that shorter partitions see it as unused.
            this.parts[currentIndex] = old;
        }
    }
}

