import npyjs from 'npyjs';
import ndarray from 'ndarray'; // TODO: consider replacing with https://github.com/stdlib-js/ndarray
import concatenate from 'ndarray-concat-rows';

// Inclusive running product of a sequence, e.g. cumProd([2, 3, 4]) -> [2, 6, 24].
// Uses the input's own .map, so a typed array in yields a typed array out.
const cumProd = (values) => {
    let product = 1;
    return values.map((value) => (product *= value));
};

/**
 * Multiply every element of an ndarray by a scalar.
 *
 * @param arr - Source ndarray. Writes go only to the array returned by
 *     `concatenate([arr])`, which is relied on here to be a fresh copy.
 * @param s - Scalar factor.
 * @returns The copied ndarray with each element multiplied by `s`.
 */
function scalarMul(arr, s) {
    const scaled = concatenate([arr]);
    const { data } = scaled;
    for (const index of data.keys()) {
        data[index] *= s;
    }
    return scaled;
}

/**
 * Conditional element-wise map.
 *
 * For each flat index i of the copy, the new value is
 * `ifFunction(arr.get(i))` when `checkFunction(conditionArr.get(i))` is
 * truthy, and `elseFunction(arr.get(i))` otherwise. Elements are read with a
 * single index via `get(i)`, so presumably `arr` and `conditionArr` are 1-D
 * with matching lengths — TODO confirm at call sites.
 *
 * @returns The ndarray produced by `concatenate([arr])`, holding the mapped values.
 */
function mapIf(arr, conditionArr, checkFunction, ifFunction, elseFunction) {
    const mapped = concatenate([arr]);
    const { data } = mapped;
    for (const index of data.keys()) {
        const transform = checkFunction(conditionArr.get(index)) ? ifFunction : elseFunction;
        data[index] = transform(arr.get(index));
    }
    return mapped;
}

/**
 * Power-scaling curve data for one CPU, loaded from NumPy `.npy` files.
 *
 * Layout of the curve array, as read by `load()`:
 * - column 0 holds the watt ticks;
 * - columns `1 + 2i` and `2 + 2i` hold the lower and upper curve for
 *   `curveMetadata.confidences[i]`;
 * - the last column holds the nonzero ratios, which are scaled by 1e-3 on load.
 */
class CPUPowerScalingCurveData {

    /**
     * Load a `.npy` file and wrap its data in an ndarray view.
     *
     * @param {string} url - Location of the `.npy` file.
     * @returns {Promise<object>} ndarray with the file's shape and strides that
     *     match its memory order (C / row-major or Fortran / column-major).
     */
    static async loadArray(url) {
        const rawData = await new npyjs().load(url);

        // Work on a copy: Array.prototype.reverse() reverses in place, and
        // reversing rawData.shape itself would hand ndarray a transposed shape
        // for every C-ordered, non-square array.
        const shape = [...rawData.shape];

        // Strides are exclusive running products of the dimensions. For Fortran
        // order they run first-to-last; for C order they run last-to-first.
        const stride = rawData.fortranOrder ?
            [1].concat(cumProd(shape).slice(0, -1)) :
            [1].concat(cumProd([...shape].reverse()).slice(0, -1)).reverse();

        return new ndarray(rawData.data, shape, stride);
    }

    /**
     * Load every array described by `curveMetadata` and build an instance.
     *
     * @param {object} curveMetadata - Must provide `url` and `confidences`.
     *     `rawDataUrl` is optional.
     * @returns {Promise<CPUPowerScalingCurveData>}
     * @throws {Error} If the curve array does not have exactly
     *     `2 * confidences.length + 2` columns.
     */
    static async load(curveMetadata) {
        // Fetch both files concurrently. Without raw data, substitute an empty
        // (0 x 2) array so rawPoints still has the two columns getPoints() picks.
        const curvesArrayPromise = CPUPowerScalingCurveData.loadArray(curveMetadata.url);
        const pointsArrayPromise = (curveMetadata.rawDataUrl != null) ?
            CPUPowerScalingCurveData.loadArray(curveMetadata.rawDataUrl) :
            Promise.resolve(new ndarray(new Int16Array(), [0, 2]));

        const [curvesArray, pointsArray] = await Promise.all([curvesArrayPromise, pointsArrayPromise]);

        if (curvesArray.shape[1] !== 2 * curveMetadata['confidences'].length + 2) {
            throw new Error("loaded curve data does not provide enough curves for all confidence levels");
        }

        const wattTicks = curvesArray.pick(null, 0);
        const rawNonzeroRatios = curvesArray.pick(null, curvesArray.shape[1] - 1);
        // Copy into a Float32Array before scaling, so the 1e-3 factor is not
        // truncated if the column is stored as integers (presumably thousandths
        // — confirm against the data export).
        const nonzeroRatios = scalarMul(
            new ndarray(Float32Array.from(concatenate([rawNonzeroRatios]).data), rawNonzeroRatios.shape),
            1e-3
        );

        const curves = Object.fromEntries(curveMetadata['confidences'].map((confidence, index) => [
            confidence,
            {
                lower: curvesArray.pick(null, 1 + 2 * index),
                upper: curvesArray.pick(null, 2 + 2 * index),
            },
        ]));

        return new CPUPowerScalingCurveData(curveMetadata, wattTicks, curves, nonzeroRatios, pointsArray);
    }

    /**
     * @param {object} curveMetadata - Metadata object that the getters read from.
     * @param {object} wattTicks - 1-D ndarray of watt ticks.
     * @param {object} curves - Map from confidence key to `{ lower, upper }` ndarrays.
     * @param {object} nonzeroRatios - 1-D ndarray of nonzero ratios, one per watt tick.
     * @param {object} rawPoints - ndarray whose columns 0 and 1 are the raw x and y values.
     */
    constructor(curveMetadata, wattTicks, curves, nonzeroRatios, rawPoints) {
        this.curveMetadata = curveMetadata;
        this.curves = curves;
        this.wattTicks = wattTicks;
        this.nonzeroRatios = nonzeroRatios;
        this.rawPoints = rawPoints;
    }

    /** @returns The `id` entry of the curve metadata. */
    getId() {
        return this.curveMetadata['id'];
    }

    /** @returns The `name` entry of the curve metadata. */
    getLabel() {
        return this.curveMetadata['name'];
    }

    /** @returns The `defaultPower` entry of the curve metadata. */
    getDefaultPower() {
        return this.curveMetadata['defaultPower'];
    }

    /**
     * @returns The `maxPower` entry of the curve metadata. Falls back to
     *     getDefaultPower() when `maxPower` is missing or falsy (e.g. 0).
     */
    getMaxPower() {
        const maxPower = this.curveMetadata['maxPower'];
        return maxPower ? maxPower : this.getDefaultPower();
    }

    /** @returns The `score` entry of the curve metadata. */
    getDefaultScore() {
        return this.curveMetadata['score'];
    }

    /** @returns A shallow copy of the configured confidence levels. */
    getAvailableConfidenceLevels() {
        return [...this.curveMetadata['confidences']];
    }

    /** @returns The `showAtStart` entry of the curve metadata. */
    shouldShowAtStart() {
        return this.curveMetadata['showAtStart'];
    }

    /** @returns The watt-tick ndarray (not a copy). */
    getWattTicks() {
        return this.wattTicks;
    }

    /**
     * Mask out entries whose nonzero ratio does not exceed a threshold.
     *
     * @param {object} arr - 1-D ndarray aligned with the nonzero ratios.
     * @param {number} [cutoffThreshold=0] - Entries are kept only where the
     *     nonzero ratio is strictly greater than this value.
     * @returns A copy of `arr` with masked entries set to NaN and the first
     *     element forced to 0.
     */
    filterCutoff(arr, cutoffThreshold = 0) {
        const filtered = mapIf(arr, this.nonzeroRatios, nzr => nzr > cutoffThreshold, v => v, () => NaN);
        filtered.data[0] = 0; // TODO: hacky fix to ensure that the plotted trace always starts at 0 so that gradients are displayed correctly;
                              // otherwise gradients indicating probability start with first non-NaN value, which is incorrect
        return filtered;
    }

    /**
     * Build one trace per confidence level inside [minConfidence, maxConfidence].
     *
     * @param {number} [minConfidence=0] - Inclusive lower bound.
     * @param {number} [maxConfidence=1] - Inclusive upper bound.
     * @param {number} [multiplier=1] - Factor applied to the y values.
     * @returns {Array<{confidence: string, x: *, y: *}>} Traces in descending
     *     numeric confidence order. `x` is the watt ticks followed by the watt
     *     ticks reversed. `y` is the cutoff-filtered lower curve followed by the
     *     reversed cutoff-filtered upper curve, times `multiplier` (presumably a
     *     closed outline for a filled band). `confidence` is the string key of
     *     `this.curves`.
     */
    getCurves(minConfidence = 0, maxConfidence = 1, multiplier = 1) {
        return Object.keys(this.curves)
            // Object keys are strings: compare and sort them as numbers. A bare
            // .sort() orders them lexicographically.
            .filter(confidence => (minConfidence <= Number(confidence)) && (Number(confidence) <= maxConfidence))
            .sort((a, b) => Number(b) - Number(a))
            .map(confidence => {
                const curve = this.curves[confidence];
                const threshold = 1.0 - confidence;
                const xs = concatenate([this.wattTicks, this.wattTicks.step(-1)]);
                const ys = scalarMul(concatenate([
                    this.filterCutoff(curve['lower'], threshold),
                    this.filterCutoff(curve['upper'], threshold).step(-1)
                ]), multiplier);
                return {
                    confidence,
                    x: xs.data,
                    y: ys.data,
                };
            });
    }

    /**
     * @param {number} [multiplier=1] - Factor applied to the y values.
     * @returns {{x: object, y: object}} Column 0 of the raw points as `x` and
     *     column 1 times `multiplier` as `y`. Unlike getCurves(), these are
     *     ndarrays rather than their backing data arrays.
     */
    getPoints(multiplier = 1) {
        const xs = concatenate([this.rawPoints.pick(null, 0)]);
        const ys = scalarMul(concatenate([this.rawPoints.pick(null, 1)]), multiplier);

        return {
            'x': xs,
            'y': ys
        };
    }

    /**
     * @returns The nonzero ratios followed by the same ratios reversed. This is
     *     the same forward-then-backward layout getCurves() uses for `x`.
     */
    getNonzeroRatios() {
        return concatenate([this.nonzeroRatios, this.nonzeroRatios.step(-1)]);
    }
}

export default CPUPowerScalingCurveData;
