"use strict";
// Appears to be TypeScript-compiled CommonJS output: flag the module for ES-module interop.
Object.defineProperty(exports, "__esModule", { value: true });
// Project helper module; `utils_1.range(n)` is used below as an axis list
// (presumably [0, 1, ..., n - 1] — confirm against ./utils).
const utils_1 = require("./utils");
// Slice sentinel: an empty range descriptor (no start/end/step), which
// enumerates every position along an axis when passed to `NDArray#slice`.
exports.All = Object.freeze({});
// Slice sentinel: `null` entries passed to `NDArray#slice` insert a new axis of length 1.
exports.NewAxis = null;
function range(start, end, step) {
    if (start != null && end == null) {
        end = start;
        start = 0;
    }
    if (step == null) {
        step = 1;
    }
    return { start, end, step };
}
exports.range = range;
/**
 * Lazily yield the positions a slice descriptor selects along an axis of length `size`.
 *
 * Semantics are similar to Python/NumPy slicing: negative bounds count from the
 * end of the axis, omitted bounds are open in the walking direction, and
 * positions outside [0, size) are never produced. A step of 0 (after flooring)
 * throws when iteration starts.
 *
 * @param {{start?: number, end?: number, step?: number}} range
 * @param {number} size - length of the axis being sliced.
 * @yields {number}
 * @throws {Error} 'Step cannot be 0'
 */
function* enumerateRange(range, size) {
    const rawStep = range.step == null ? 1 : range.step;
    const walksForward = rawStep > 0;
    // Omitted bounds become infinities pointing in the walking direction;
    // negative bounds are relative to the end of the axis.
    const resolveBound = (bound, whenOmitted) => {
        if (bound == null) {
            return whenOmitted;
        }
        return bound < 0 ? size + bound : bound;
    };
    const from = Math.floor(resolveBound(range.start, walksForward ? -Infinity : Infinity));
    const to = Math.floor(resolveBound(range.end, walksForward ? Infinity : -Infinity));
    const stride = Math.floor(rawStep);
    if (stride === 0) {
        throw new Error('Step cannot be 0');
    }
    if (stride > 0) {
        let cursor = Math.max(from, 0);
        while (cursor < to && cursor < size) {
            yield cursor;
            cursor += stride;
        }
    }
    else {
        let cursor = Math.min(from, size - 1);
        while (cursor > to && cursor >= 0) {
            yield cursor;
            cursor += stride;
        }
    }
}
/**
 * Enumerate the cartesian product of per-axis slice descriptors.
 *
 * Each yielded item is `[sourceIndex, targetIndex]`: `sourceIndex` holds the
 * positions selected in the original array (one per axis, via enumerateRange),
 * and `targetIndex` is the running position of each selection within its axis,
 * i.e. where the element lands in the sliced result. The first axis varies
 * slowest. Every yielded index is a fresh array.
 *
 * @param {Array<object>} ranges - one slice descriptor per axis.
 * @param {number[]} shape - shape of the array being sliced.
 * @returns {Iterable<[number[], number[]]>} an empty array when there are no axes.
 * @throws {Error} when `ranges` and `shape` have different lengths.
 */
function enumerateRanges(ranges, shape) {
    if (ranges.length !== shape.length) {
        throw new Error('Sizes of ranges and shape are different');
    }
    if (ranges.length === 0) {
        return [];
    }
    const lastAxis = ranges.length - 1;
    // Depth-first walk: fix one axis at a time and recurse into the next.
    function* walk(axis, sourcePrefix, targetPrefix) {
        let position = 0;
        for (const index of enumerateRange(ranges[axis], shape[axis])) {
            const source = [...sourcePrefix, index];
            const target = [...targetPrefix, position];
            if (axis === lastAxis) {
                yield [source, target];
            }
            else {
                yield* walk(axis + 1, source, target);
            }
            position++;
        }
    }
    return walk(0, [], []);
}
/**
 * Count how many positions the slice descriptor `r` selects on an axis of length `size`.
 *
 * @param {object} r - slice descriptor.
 * @param {number} size - axis length.
 * @returns {number}
 */
function countRange(r, size) {
    let total = 0;
    const walker = enumerateRange(r, size);
    while (!walker.next().done) {
        total++;
    }
    return total;
}
/**
 * Dense n-dimensional array: a flat `data` array laid out in row-major order
 * (last axis varies fastest, see flattenIndices) plus the `shape` of every axis.
 */
class NDArray {
    /**
     * @param {Array} data - flat element storage; its length must equal the product of `shape`.
     * @param {number[]} shape - extent of every axis.
     * @throws {Error} 'invalid array and shape' when the lengths disagree.
     */
    constructor(data, shape) {
        if (data.length !== shapeProduct(shape)) {
            throw new Error('invalid array and shape');
        }
        this.data = data;
        this.shape = shape;
        this.size = shapeProduct(shape);
    }
    /**
     * Create an array with no elements.
     * NOTE(review): `shape` is ignored and the result always has shape `[]`;
     * confirm whether callers expect the requested (zero-size) shape instead.
     */
    static empty(shape) {
        return new NDArray([], []);
    }
    /**
     * Read one element. A plain number is treated as a one-element index tuple.
     * @param {number|number[]} [indices=0]
     */
    get(indices = 0) {
        if (typeof indices === 'number') {
            indices = [indices];
        }
        const idx = flattenIndices(indices, this.shape);
        return this.data[idx];
    }
    /**
     * Overwrite one element in place.
     * @param {number|number[]} indices
     * @param {*} value
     */
    set(indices, value) {
        if (typeof indices === 'number') {
            indices = [indices];
        }
        const idx = flattenIndices(indices, this.shape);
        this.data[idx] = value;
    }
    /**
     * Replace one element in place with `updater(currentValue)`.
     * @param {number|number[]} indices
     * @param {(current: *) => *} updater
     */
    update(indices, updater) {
        if (typeof indices === 'number') {
            indices = [indices];
        }
        const idx = flattenIndices(indices, this.shape);
        this.data[idx] = updater(this.data[idx]);
    }
    /**
     * Return a copy with a new shape holding the same number of elements.
     * One entry may be -1; its extent is inferred from the others.
     * The caller's `shape` array is not modified.
     *
     * @param {number[]} shape
     * @returns {NDArray}
     * @throws {Error} 'incompatible shape'
     */
    reshape(shape) {
        const target = shape.slice();
        const inferred = target.indexOf(-1);
        if (inferred !== -1) {
            const total = shapeProduct(this.shape);
            // Multiply the explicit extents starting from 1 so that reshape([-1])
            // works; shapeProduct([]) is 0 in this module.
            const known = target.filter((n) => n >= 0).reduce((a, b) => a * b, 1);
            if (known !== 0 && total % known === 0) {
                target[inferred] = total / known;
            }
        }
        if (!isReshapable(this.shape, target)) {
            throw new Error('incompatible shape');
        }
        return new NDArray(this.data.slice(), target);
    }
    /**
     * Permute the axes. Without `axes` the axis order is reversed.
     * @param {number[]} [axes] - a permutation of 0..rank-1.
     * @returns {NDArray} a new array.
     * @throws {Error} 'Invalid axes'
     */
    transpose(axes) {
        if (axes) {
            if (axes.length !== this.shape.length) {
                throw new Error('Invalid axes');
            }
            for (let i = 0; i < this.shape.length; i++) {
                if (!axes.includes(i)) {
                    throw new Error('Invalid axes');
                }
            }
        }
        else {
            axes = utils_1.range(this.shape.length).reverse();
        }
        const resultShape = axes.map((s) => this.shape[s]);
        const newArray = zeros(resultShape);
        for (const idx of enumerateIndices(this.shape)) {
            const resultIndex = axes.map((s) => idx[s]);
            newArray.set(resultIndex, this.get(idx));
        }
        return newArray;
    }
    /**
     * Exchange two axes.
     * @param {number} a1
     * @param {number} a2
     * @returns {NDArray}
     */
    swapAxes(a1, a2) {
        if (a1 < 0 || a1 >= this.shape.length) {
            throw new Error('Invalid axis 1');
        }
        if (a2 < 0 || a2 >= this.shape.length) {
            throw new Error('Invalid axis 2');
        }
        const i = utils_1.range(this.shape.length);
        i[a1] = a2;
        i[a2] = a1;
        return this.transpose(i);
    }
    /** Element-wise `this + x` with broadcasting (see module-level `add`). */
    add(x, out) {
        return exports.add(this, x, out);
    }
    /** Element-wise `this - x` with broadcasting. */
    sub(x, out) {
        return exports.sub(this, x, out);
    }
    /** Element-wise `this * x` with broadcasting. */
    mul(x, out) {
        return exports.mul(this, x, out);
    }
    /** Element-wise `this / x` with broadcasting. */
    div(x, out) {
        return exports.div(this, x, out);
    }
    /** Element-wise `this ** x` with broadcasting. */
    pow(x, out) {
        return exports.pow(this, x, out);
    }
    /** Element-wise negation. */
    neg(out) {
        return exports.neg(this, out);
    }
    /** Position of the smallest element along `axis`. */
    argMin(axis) {
        return argMin(this, axis);
    }
    /** Position of the largest element along `axis`. */
    argMax(axis) {
        return argMax(this, axis);
    }
    /**
     * Return a copy of a sub-array.
     *
     * Each argument applies to the next axis:
     * - a number selects one position (negative counts from the end) and drops the axis;
     * - a descriptor from `range()` (or `All`) keeps the axis with the selected positions;
     * - `null` (`NewAxis`) inserts a new axis of length 1 and consumes no input axis.
     * Missing trailing axes are treated as `All`. When every axis is indexed by a
     * number, the result has shape [1].
     *
     * @param {...(number|object|null)} indexOrRanges
     * @returns {NDArray}
     * @throws {Error} 'Too many indices' / 'Index out of range'
     */
    slice(...indexOrRanges) {
        const explicit = indexOrRanges.filter((e) => e != null);
        if (explicit.length > this.shape.length) {
            throw new Error('Too many indices');
        }
        while (explicit.length < this.shape.length) {
            explicit.push(exports.All);
            indexOrRanges.push(exports.All);
        }
        const ranges = explicit.map((ir, axis) => {
            if (typeof ir !== 'number') {
                return ir;
            }
            const size = this.shape[axis];
            if (ir < -size || ir >= size) {
                throw new Error('Index out of range');
            }
            // Normalize negatives first: range(-1, 0) would describe an empty slice.
            const k = ir < 0 ? ir + size : ir;
            return range(k, k + 1);
        });
        const resultShape = ranges.map((r, axis) => countRange(r, this.shape[axis]));
        // Equal counts only imply the identity selection when every range walks
        // forward; a reversed range has the same count but a different order.
        const isIdentity = isSameShape(this.shape, resultShape) &&
            ranges.every((r) => r.step == null || r.step > 0);
        let result;
        if (isIdentity) {
            result = new NDArray(this.data.slice(), this.shape);
        }
        else {
            result = zeros(resultShape);
            if (result.size > 0) {
                for (const [src, dst] of enumerateRanges(ranges, this.shape)) {
                    result.set(dst, this.get(src));
                }
            }
        }
        // Drop integer-indexed axes and insert length-1 axes for NewAxis entries.
        const normalizedShape = [];
        let axis = 0;
        for (const ir of indexOrRanges) {
            if (ir == null) {
                normalizedShape.push(1);
                continue;
            }
            if (typeof ir !== 'number') {
                normalizedShape.push(resultShape[axis]);
            }
            axis++;
        }
        // A fully indexed result is a single element; this module represents
        // scalars with shape [1] (see sum/min/max/einsum).
        return result.reshape(normalizedShape.length > 0 ? normalizedShape : [1]);
    }
    /** Sum over the given axes (all axes by default). */
    sum(axes) {
        return sum(this, axes);
    }
    /** Product over the given axes (all axes by default). */
    prod(axes) {
        return prod(this, axes);
    }
    /** Arithmetic mean over the given axes (all axes by default). */
    mean(axes) {
        return mean(this, axes);
    }
    /** Limit every element to the interval [min, max]. */
    clip(min, max, out) {
        return clip(this, min, max, out);
    }
    /** Minimum over the given axis or axes. */
    min(axisOrAxes) {
        return min(this, axisOrAxes);
    }
    /** Maximum over the given axis or axes. */
    max(axisOrAxes) {
        return max(this, axisOrAxes);
    }
    /** Apply `f` to every element (see module-level operateUnary). */
    operateUnary(f, out) {
        return operateUnary(f, this, out);
    }
    /** Alias of `operateUnary`. */
    map(f, out) {
        return operateUnary(f, this, out);
    }
}
exports.NDArray = NDArray;
/**
 * True when two shapes have the same rank and the same extent on every axis.
 * @param {number[]} shape1
 * @param {number[]} shape2
 * @returns {boolean}
 */
function isSameShape(shape1, shape2) {
    if (shape1.length !== shape2.length) {
        return false;
    }
    return !shape1.some((extent, axis) => extent !== shape2[axis]);
}
/**
 * Create an array of the given shape with every element set to `x`.
 * @param {*} x - fill value.
 * @param {number|number[]} shapeOrNumber - a shape, or a length for a 1-D array.
 * @returns {NDArray}
 * @throws {Error} 'invalid shape'
 */
function repeat(x, shapeOrNumber) {
    let shape = shapeOrNumber;
    if (typeof shape === 'number') {
        shape = [shape];
    }
    if (!isValidShape(shape)) {
        throw new Error('invalid shape');
    }
    const filled = new Array(shapeProduct(shape)).fill(x);
    return new NDArray(filled, shape);
}
exports.repeat = repeat;
/**
 * Create an array of the given shape filled with 0.
 * @param {number|number[]} shapeOrNumber - a shape, or a length for a 1-D array.
 * @returns {NDArray}
 */
function zeros(shapeOrNumber) {
    const fillValue = 0;
    return repeat(fillValue, shapeOrNumber);
}
exports.zeros = zeros;
/**
 * Create a zero-filled array with the same shape as `array`.
 * @param {NDArray} array
 * @returns {NDArray}
 */
function zerosLike(array) {
    const { shape } = array;
    return zeros(shape);
}
exports.zerosLike = zerosLike;
/**
 * Number of elements described by a shape.
 * An empty shape counts as 0 elements in this module (not 1 as in NumPy).
 * @param {number[]} indices - the shape.
 * @returns {number}
 */
function shapeProduct(indices) {
    return indices.length === 0
        ? 0
        : indices.reduce((product, extent) => product * extent, 1);
}
/**
 * True when every extent is a non-negative integer.
 *
 * Fractional extents are rejected: they make `new Array(n)` throw a RangeError
 * in repeat() and allow nonsensical shapes through reshape().
 *
 * @param {number[]} shape
 * @returns {boolean}
 */
function isValidShape(shape) {
    return shape.every((n) => Number.isInteger(n) && n >= 0);
}
/**
 * True when both shapes are valid and describe the same number of elements.
 * @param {number[]} oldShape
 * @param {number[]} newShape
 * @returns {boolean}
 */
function isReshapable(oldShape, newShape) {
    if (!isValidShape(oldShape) || !isValidShape(newShape)) {
        return false;
    }
    return shapeProduct(oldShape) === shapeProduct(newShape);
}
/**
 * Yield every index tuple of `shape` in row-major order (last axis fastest).
 *
 * The SAME array object is yielded on every step, and every entry is recomputed
 * from the flat counter before the next yield. Callers may therefore mutate it
 * temporarily, but must copy it if they need to keep it. Shapes with no axes or
 * with a zero-length axis yield nothing.
 *
 * @param {number[]} shape
 * @yields {number[]}
 */
function* enumerateIndices(shape) {
    const rank = shape.length;
    const total = shapeProduct(shape);
    if (rank === 0 || total === 0) {
        return;
    }
    const cursor = new Array(rank);
    for (let flat = 0; flat < total; flat++) {
        // Peel off mixed-radix digits from the last axis towards the first.
        let rest = flat;
        let axis = rank - 1;
        while (axis > 0) {
            const extent = shape[axis];
            const digit = rest % extent;
            cursor[axis] = digit;
            rest = (rest - digit) / extent;
            axis--;
        }
        cursor[0] = rest;
        yield cursor;
    }
}
/**
 * Build an NDArray from nested JavaScript arrays of numbers.
 *
 * The shape is taken from the lengths along the first elements. Every nested
 * list must then match that shape exactly. A flat length check alone would
 * accept ragged input such as [[1, [2]], [3, 4]].
 *
 * @param {Array} raw - nested arrays of numbers, e.g. [[1, 2], [3, 4]].
 * @returns {NDArray}
 * @throws {Error} 'invalid argument' for non-arrays, ragged nesting or non-number leaves.
 */
function createArray(raw) {
    if (!Array.isArray(raw)) {
        throw new Error('invalid argument');
    }
    const shape = [];
    for (let a = raw; Array.isArray(a); a = a[0]) {
        shape.push(a.length);
    }
    // Leaves must sit exactly at depth shape.length; every list above them must
    // have the length recorded for its depth.
    const matchesShape = (node, depth) => {
        if (depth === shape.length) {
            return !Array.isArray(node);
        }
        return Array.isArray(node) &&
            node.length === shape[depth] &&
            node.every((child) => matchesShape(child, depth + 1));
    };
    if (!matchesShape(raw, 0)) {
        throw new Error('invalid argument');
    }
    const data = flatten(raw);
    if (data.length !== shapeProduct(shape) || data.some((x) => typeof x !== 'number')) {
        throw new Error('invalid argument');
    }
    return new NDArray(data, shape);
}
exports.createArray = createArray;
/**
 * Depth-first flatten of arbitrarily nested arrays.
 * @param {Iterable} array - possibly nested input.
 * @param {Array} [dest=[]] - accumulator that receives the leaves in order.
 * @returns {Array} `dest`.
 */
function flatten(array, dest = []) {
    for (const item of array) {
        if (!Array.isArray(item)) {
            dest.push(item);
            continue;
        }
        flatten(item, dest);
    }
    return dest;
}
/**
 * Convert an index tuple into an offset in row-major flat storage.
 * No bounds checking is performed.
 * @param {number[]} indices
 * @param {number[]} shape
 * @returns {number}
 */
function flattenIndices(indices, shape) {
    const rank = shape.length;
    // strides[axis] = product of the extents of all later axes.
    const strides = new Array(Math.max(rank, 1)).fill(1);
    for (let axis = rank - 2; axis >= 0; axis--) {
        strides[axis] = strides[axis + 1] * shape[axis + 1];
    }
    let offset = 0;
    indices.forEach((index, axis) => {
        offset += index * strides[axis];
    });
    return offset;
}
/**
 * Strip surrounding whitespace; a named function so it can be passed to `Array#map`.
 * @param {string} text
 * @returns {string}
 */
function trim(text) {
    return text.trim();
}
/**
 * Parse an einsum expression such as 'i,j;j,k->i,k'.
 *
 * Operands are separated by ';' and the index names of one operand by ','.
 * The output index list follows '->' and may be empty. Whitespace around names
 * is ignored.
 *
 * @param {string} expr
 * @returns {[string[][], string[]]} per-operand index names and output index names.
 * @throws {Error} when the expression does not contain exactly one '->'.
 */
function parseExprForEinsum(expr) {
    const sides = expr.split('->');
    // Without this check a missing '->' surfaces as a TypeError on `undefined.length`.
    if (sides.length !== 2) {
        throw new Error(`Invalid einsum expression '${expr}': expected exactly one '->'`);
    }
    const [lhs, rhs] = sides.map(trim);
    return [
        lhs.split(';').map((operand) => operand.split(',').map(trim)),
        rhs.length > 0 ? rhs.split(',').map(trim) : []
    ];
}
/**
 * Einstein summation over the given arrays.
 *
 * Expression syntax is the one accepted by parseExprForEinsum. Operands are
 * separated by ';', the index names of one operand by ',', and the output
 * indices follow '->'. Examples: 'i,j;j,k->i,k' is a matrix product and
 * 'i,i->' is a trace. Every index not listed after '->' is summed over. An
 * empty output list produces a result of shape [1].
 *
 * @param {string} expr
 * @param {...NDArray} arrays - exactly one array per operand in `expr`.
 * @returns {NDArray}
 * @throws {Error} on operand/array count mismatch, rank mismatch, inconsistent
 *   extents for a shared index name, or an unknown output index name.
 */
function einsum(expr, ...arrays) {
    const [indexNameLists, resultIndexNames] = parseExprForEinsum(expr);
    if (indexNameLists.length === 0) {
        throw new Error('Specify one or more elements for 1st argument');
    }
    if (indexNameLists.length !== arrays.length) {
        throw new Error(`Number of arrays (${arrays.length}) does not match number of operands in expression (${indexNameLists.length})`);
    }
    for (const [i, indexNames] of indexNameLists.entries()) {
        if (indexNames.length !== arrays[i].shape.length) {
            throw new Error(`Number of index names and rank of array at ${i} are different`);
        }
    }
    // A Map (not a plain object with `in`) so names like 'constructor' or
    // '__proto__' do not collide with Object.prototype members.
    const idByIndexName = new Map();
    // dims[id] = extent of the distinct index with that id.
    const dims = [];
    const indexIdLists = indexNameLists.map((indexNames, i) => indexNames.map((iName, j) => {
        const extent = arrays[i].shape[j];
        if (idByIndexName.has(iName)) {
            const existing = idByIndexName.get(iName);
            if (dims[existing] !== extent) {
                throw new Error('shape is not matched');
            }
            return existing;
        }
        const id = dims.length;
        dims.push(extent);
        idByIndexName.set(iName, id);
        return id;
    }));
    const resultIndexIds = [];
    const resultShape = [];
    for (const iName of resultIndexNames) {
        const id = idByIndexName.get(iName);
        if (id == null) {
            throw new Error(`Unknown index name '${iName}'`);
        }
        resultIndexIds.push(id);
        resultShape.push(dims[id]);
    }
    const result = zeros(resultShape.length > 0 ? resultShape : [1]);
    // Per-operand index buffers, refilled at every point of the iteration space.
    const ai = indexIdLists.map((ids) => new Array(ids.length));
    // Visit every combination of all distinct indices and add the product of the
    // addressed operand elements to the matching output cell.
    for (const idx of enumerateIndices(dims)) {
        const ri = resultShape.length > 0 ? resultIndexIds.map((i) => idx[i]) : [0];
        let p = 1;
        for (let i = 0; i < arrays.length; i++) {
            for (let j = 0; j < indexIdLists[i].length; j++) {
                ai[i][j] = idx[indexIdLists[i][j]];
            }
            p *= arrays[i].get(ai[i]);
        }
        result.set(ri, result.get(ri) + p);
    }
    return result;
}
exports.einsum = einsum;
/**
 * Apply the binary kernel `f` element-wise with broadcasting.
 *
 * Plain numbers are promoted to arrays whose axes all have length 1. Both
 * operands must then have the same rank. Along each axis the extents must be
 * equal, or one of them must be 1, in which case that operand is repeated
 * along the axis.
 *
 * @param {(x: *, y: *) => *} f - element-wise kernel.
 * @param {NDArray|number} a - left operand.
 * @param {NDArray|number} b - right operand.
 * @param {NDArray} [out] - optional destination; must already have the broadcast shape.
 * @returns {NDArray} `out` when given, otherwise a new array.
 * @throws {Error} 'Incompatible shape' or 'Shape of `out` is incompatible'.
 */
function operate(f, a, b, out) {
    let lhs = a;
    let rhs = b;
    if (typeof lhs === 'number') {
        lhs = typeof rhs === 'number'
            ? createArray([lhs])
            : createArray([lhs]).reshape(rhs.shape.map(() => 1));
    }
    if (typeof rhs === 'number') {
        rhs = createArray([rhs]).reshape(lhs.shape.map(() => 1));
    }
    const rank = lhs.shape.length;
    if (rank !== rhs.shape.length) {
        throw new Error('Incompatible shape');
    }
    for (let d = 0; d < rank; d++) {
        const sa = lhs.shape[d];
        const sb = rhs.shape[d];
        if (sa !== 1 && sb !== 1 && sa !== sb) {
            throw new Error('Incompatible shape');
        }
    }
    // A length-1 axis stretches to the other operand's extent, including 0.
    // Math.max would turn a (0, 1) pair into 1 and read past the empty operand.
    const resultShape = lhs.shape.map((sa, d) => (sa === 1 ? rhs.shape[d] : sa));
    if (out != null && !isSameShape(out.shape, resultShape)) {
        throw new Error('Shape of `out` is incompatible');
    }
    // Broadcast axes (extent <= 1) always read position 0 of that operand.
    const ma = lhs.shape.map((s) => (s > 1 ? 1 : 0));
    const mb = rhs.shape.map((s) => (s > 1 ? 1 : 0));
    const result = out || zeros(resultShape);
    const ia = new Array(rank);
    const ib = new Array(rank);
    // enumerateIndices walks row-major order, which matches the flat layout of
    // `result.data`, so a running counter addresses the output directly.
    let flat = 0;
    for (const idx of enumerateIndices(resultShape)) {
        for (let j = 0; j < rank; j++) {
            ia[j] = idx[j] * ma[j];
            ib[j] = idx[j] * mb[j];
        }
        result.data[flat] = f(lhs.get(ia), rhs.get(ib));
        flat++;
    }
    return result;
}
function createUniversalBinaryFunction(f) {
    return (a, b, out) => operate(f, a, b, out);
}
exports.add = createUniversalBinaryFunction((a, b) => a + b);
exports.sub = createUniversalBinaryFunction((a, b) => a - b);
exports.mul = createUniversalBinaryFunction((a, b) => a * b);
exports.div = createUniversalBinaryFunction((a, b) => a / b);
exports.pow = createUniversalBinaryFunction((a, b) => Math.pow(a, b));
function operateUnary(f, a, out) {
    if (typeof a === 'number') {
        a = createArray([a]);
    }
    if (out != null && !isSameShape(a.shape, out.shape)) {
        throw new Error('Shape of `out` is incompatible');
    }
    const result = out || zerosLike(a);
    for (const i of enumerateIndices(result.shape)) {
        result.set(i, f(a.get(i)));
    }
    return result;
}
exports.operateUnary = operateUnary;
exports.map = operateUnary;
function createUniversalUnaryOperator(f) {
    return (a, out) => operateUnary(f, a, out);
}
exports.neg = createUniversalUnaryOperator((a) => -a);
exports.exp = createUniversalUnaryOperator(Math.exp);
exports.abs = createUniversalUnaryOperator(Math.abs);
/**
 * Shared implementation of argMin / argMax.
 *
 * With an axis, the result has that axis removed. A reduction that removes the
 * only axis yields shape [1], which matches how sum/min/max represent scalars.
 * Without an axis (null/undefined), the flattened array is searched and a
 * shape-[1] array holding the flat row-major position is returned.
 *
 * @param {NDArray} array
 * @param {number|null|undefined} axis
 * @param {(candidate: *, best: *) => boolean} isBetter - true when `candidate` should replace `best`;
 *   ties keep the earliest position.
 * @param {string} name - operation name used in error messages.
 * @returns {NDArray}
 * @throws {Error} 'invalid axis', or when searching an empty array without an axis.
 */
function argReduce(array, axis, isBetter, name) {
    const shape = array.shape;
    if (axis == null) {
        const total = shapeProduct(shape);
        if (total === 0) {
            throw new Error(`attempt to get ${name} of an empty array`);
        }
        let best = array.data[0];
        let bestIndex = 0;
        for (let k = 1; k < total; k++) {
            if (isBetter(array.data[k], best)) {
                best = array.data[k];
                bestIndex = k;
            }
        }
        return new NDArray([bestIndex], [1]);
    }
    if (!Number.isInteger(axis) || axis < 0 || axis >= shape.length) {
        throw new Error('invalid axis');
    }
    const reducedShape = shape.filter((_, d) => d !== axis);
    if (shapeProduct(shape) === 0) {
        return new NDArray([], reducedShape);
    }
    const resultShape = reducedShape.length > 0 ? reducedShape : [1];
    const result = zeros(resultShape);
    // Walk every position with the reduced axis pinned to 0, then scan along it.
    const outerShape = shape.slice();
    outerShape[axis] = 1;
    for (const start of enumerateIndices(outerShape)) {
        // Copy: enumerateIndices yields one shared buffer, which must not be mutated here.
        const probe = start.slice();
        let best = array.get(probe);
        let bestIndex = 0;
        for (let j = 1; j < shape[axis]; j++) {
            probe[axis] = j;
            const v = array.get(probe);
            if (isBetter(v, best)) {
                best = v;
                bestIndex = j;
            }
        }
        const target = reducedShape.length > 0 ? probe.filter((_, d) => d !== axis) : [0];
        result.set(target, bestIndex);
    }
    return result;
}
/**
 * Position of the smallest element along `axis` (or in the flattened array when
 * `axis` is omitted).
 * @param {NDArray} array
 * @param {number} [axis]
 * @returns {NDArray}
 */
function argMin(array, axis) {
    return argReduce(array, axis, (candidate, best) => candidate < best, 'argMin');
}
exports.argMin = argMin;
/**
 * Position of the largest element along `axis` (or in the flattened array when
 * `axis` is omitted).
 * @param {NDArray} array
 * @param {number} [axis]
 * @returns {NDArray}
 */
function argMax(array, axis) {
    return argReduce(array, axis, (candidate, best) => candidate > best, 'argMax');
}
exports.argMax = argMax;
/**
 * Validate the axes argument of a reduction.
 *
 * `null`/`undefined` selects every axis. A single number or an array of numbers
 * selects specific axes. Negative axes count from the end (-1 is the last axis).
 *
 * @param {number|number[]|null|undefined} axesOrAxis
 * @param {number[]} shape - shape of the array being reduced.
 * @returns {[number[], number[]]} `[reducedAxes, remainingAxes]`, both normalized to
 *   non-negative axis numbers; `remainingAxes` is in ascending order.
 * @throws {Error} 'invalid axes' for non-integer, out-of-range or repeated axes.
 */
function checkAxesForAggregation(axesOrAxis, shape) {
    const dim = shape.length;
    if (axesOrAxis == null) {
        return [utils_1.range(dim), []];
    }
    const requested = typeof axesOrAxis === 'number' ? [axesOrAxis] : axesOrAxis.slice();
    const axes = [];
    for (const ax of requested) {
        const normalized = ax < 0 ? ax + dim : ax;
        if (!Number.isInteger(normalized) || normalized < 0 || normalized >= dim || axes.includes(normalized)) {
            throw new Error('invalid axes');
        }
        axes.push(normalized);
    }
    const remaining = utils_1.range(dim).filter((a) => !axes.includes(a));
    return [axes, remaining];
}
/**
 * Sum of elements over the given axes (all axes when omitted).
 *
 * Reducing every axis returns a shape-[1] array; an empty input sums to 0.
 *
 * @param {NDArray} array
 * @param {number|number[]} [axisOrAxes]
 * @returns {NDArray}
 */
function sum(array, axisOrAxes = undefined) {
    const { shape } = array;
    const keptAxes = checkAxesForAggregation(axisOrAxes, shape)[1];
    if (keptAxes.length === 0) {
        let total = 0;
        for (const position of enumerateIndices(shape)) {
            total += array.get(position);
        }
        return createArray([total]);
    }
    const accumulator = zeros(keptAxes.map((axis) => shape[axis]));
    for (const position of enumerateIndices(shape)) {
        const value = array.get(position);
        accumulator.update(keptAxes.map((axis) => position[axis]), (acc) => acc + value);
    }
    return accumulator;
}
exports.sum = sum;
/**
 * Product of elements over the given axes (all axes when omitted).
 *
 * Reducing every axis returns a shape-[1] array; an empty input multiplies to 1.
 *
 * @param {NDArray} array
 * @param {number|number[]} [axisOrAxes]
 * @returns {NDArray}
 */
function prod(array, axisOrAxes = undefined) {
    const { shape } = array;
    const keptAxes = checkAxesForAggregation(axisOrAxes, shape)[1];
    if (keptAxes.length === 0) {
        let product = 1;
        for (const position of enumerateIndices(shape)) {
            product *= array.get(position);
        }
        return createArray([product]);
    }
    const accumulator = repeat(1, keptAxes.map((axis) => shape[axis]));
    for (const position of enumerateIndices(shape)) {
        const value = array.get(position);
        accumulator.update(keptAxes.map((axis) => position[axis]), (acc) => acc * value);
    }
    return accumulator;
}
exports.prod = prod;
/**
 * Arithmetic mean over the given axes (all axes when omitted).
 *
 * Each sum is divided by the number of elements folded into it, which is the
 * total element count divided by the number of output cells. For an input with
 * no elements, the sums are returned unchanged.
 *
 * @param {NDArray} array
 * @param {number|number[]} [axisOrAxes]
 * @returns {NDArray}
 */
function mean(array, axisOrAxes = undefined) {
    const totals = sum(array, axisOrAxes);
    const elementCount = shapeProduct(array.shape);
    if (elementCount === 0) {
        return totals;
    }
    const cellCount = shapeProduct(totals.shape);
    return totals.div(elementCount / cellCount);
}
exports.mean = mean;
/**
 * Limit every element to the interval [min, max].
 *
 * A `null`/`undefined` bound means "unbounded on that side". Previously an
 * omitted bound turned every element into NaN, because `Math.min(n, undefined)`
 * is NaN. When `min > max`, every element becomes `min`, because the lower
 * bound is applied last.
 *
 * @param {NDArray|number} array
 * @param {number|null} [min] - lower bound.
 * @param {number|null} [max] - upper bound.
 * @param {NDArray} [out] - optional destination with the same shape as `array`.
 * @returns {NDArray}
 */
function clip(array, min, max, out) {
    const lower = min == null ? -Infinity : min;
    const upper = max == null ? Infinity : max;
    return operateUnary((n) => Math.max(Math.min(n, upper), lower), array, out);
}
exports.clip = clip;
/**
 * Minimum over the given axes (all axes when omitted).
 *
 * Reducing every axis returns a shape-[1] array; an empty input yields Infinity.
 * NaN propagates, following `Math.min`.
 *
 * @param {NDArray} array
 * @param {number|number[]} [axisOrAxes]
 * @returns {NDArray}
 */
function min(array, axisOrAxes) {
    const { shape } = array;
    const keptAxes = checkAxesForAggregation(axisOrAxes, shape)[1];
    if (keptAxes.length === 0) {
        let smallest = Infinity;
        for (const position of enumerateIndices(shape)) {
            smallest = Math.min(smallest, array.get(position));
        }
        return createArray([smallest]);
    }
    const accumulator = repeat(Infinity, keptAxes.map((axis) => shape[axis]));
    for (const position of enumerateIndices(shape)) {
        const value = array.get(position);
        accumulator.update(keptAxes.map((axis) => position[axis]), (current) => Math.min(current, value));
    }
    return accumulator;
}
exports.min = min;
/**
 * Maximum over the given axes (all axes when omitted).
 *
 * Reducing every axis returns a shape-[1] array; an empty input yields -Infinity.
 * NaN propagates, following `Math.max`.
 *
 * @param {NDArray} array
 * @param {number|number[]} [axisOrAxes]
 * @returns {NDArray}
 */
function max(array, axisOrAxes) {
    const { shape } = array;
    const keptAxes = checkAxesForAggregation(axisOrAxes, shape)[1];
    if (keptAxes.length === 0) {
        let largest = -Infinity;
        for (const position of enumerateIndices(shape)) {
            largest = Math.max(largest, array.get(position));
        }
        return createArray([largest]);
    }
    const accumulator = repeat(-Infinity, keptAxes.map((axis) => shape[axis]));
    for (const position of enumerateIndices(shape)) {
        const value = array.get(position);
        accumulator.update(keptAxes.map((axis) => position[axis]), (current) => Math.max(current, value));
    }
    return accumulator;
}
exports.max = max;
