import {
  Assessment,
  MeasurementDimensionKeys,
  RiskCategory,
  RisksSnapshot,
  statisticalMeasurementDimensions,
} from '../../../types';
import {
  DimensionsStats,
  DimensionStats,
  RiskCutoffs,
  RisksCutoffs,
  RisksThresholds,
  RiskThresholds,
  SnapshotsStatsIndex,
} from '../../slices/risks_snapshots';

/**
 * Updates a running arithmetic mean with one more sample, without revisiting
 * earlier samples.
 * {@link https://cutt.ly/7QbBWCL Incremental mean calculation}
 *
 * @param value - the new sample.
 * @param number - 1-based count of samples including `value`.
 * @param previousMean - mean of the first `number - 1` samples.
 * @returns mean of all `number` samples.
 */
const calculateIncrementalMean = (
  value: number,
  number: number,
  previousMean: number
): number => {
  const delta = value - previousMean;
  return previousMean + delta / number;
};

/**
 * Updates a running population standard deviation with one more sample
 * (Welford-style update of the sum of squared deviations).
 * {@link https://cutt.ly/1QbBUQ4 Incremental standard deviation calculation}
 *
 * @param value - the new sample.
 * @param number - 1-based count of samples including `value`.
 * @param previousStddev - population stddev of the first `number - 1` samples.
 * @param mean - mean of all `number` samples (already updated with `value`).
 * @param previousMean - mean of the first `number - 1` samples.
 * @returns population stddev of all `number` samples.
 */
const calculateIncrementalStddev = (
  value: number,
  number: number,
  previousStddev: number,
  mean: number,
  previousMean: number
): number => {
  // Rebuild the previous sum of squared deviations from the previous stddev.
  const previousVariance = previousStddev ** 2;
  const previousSumOfSquares = previousVariance * (number - 1);
  const correction = (value - previousMean) * (value - mean);
  return Math.sqrt((previousSumOfSquares + correction) / number);
};

/**
 * Folds one snapshot's raw values for a single dimension into the running
 * statistics carried over from earlier snapshots.
 *
 * The running mean/stddev resume from the LAST element of
 * `dimensionPreviousStats`, so that array is expected to be ordered oldest →
 * newest. The running sample count starts from the sum of every element's
 * `measurementsAmount`.
 *
 * Values are transformed to LOG10 before aggregation because that range gave
 * better results; cutoffs are transformed back before comparing actual values.
 * {@link https://soomo.height.app/T-2009#b6b81259-9484-4e06-ae9d-8fb6dc6bff9d Example}
 *
 * @param dimension - dimension key, copied onto the result.
 * @param dimensionMeasurementsValues - this snapshot's raw values; each goes through `Math.log10`.
 * @param dimensionPreviousStats - stats of earlier snapshots for this dimension.
 * @returns `measurementsAmount` counts this snapshot only, while `mean`/`stddev`
 *   (log10 space) cover every measurement seen so far.
 */
const calculateDimensionStats = (
  dimension: MeasurementDimensionKeys,
  dimensionMeasurementsValues: number[],
  dimensionPreviousStats: DimensionStats[]
): DimensionStats => {
  let countBefore = 0;
  dimensionPreviousStats.forEach((stat) => {
    countBefore += stat.measurementsAmount || 0;
  });

  const latest = dimensionPreviousStats[dimensionPreviousStats.length - 1];
  const startMean = latest ? latest.mean : 0;
  const startStddev = latest ? latest.stddev : 0;

  const { mean, stddev } = dimensionMeasurementsValues.reduce(
    (running, rawValue, position) => {
      const logValue = Math.log10(rawValue);
      // 1-based ordinal across all snapshots, so the first sample never divides by 0.
      const ordinal = countBefore + (position + 1);
      const nextMean = calculateIncrementalMean(logValue, ordinal, running.mean);
      return {
        mean: nextMean,
        stddev: calculateIncrementalStddev(
          logValue,
          ordinal,
          running.stddev,
          nextMean,
          running.mean
        ),
      };
    },
    { mean: startMean, stddev: startStddev }
  );

  return {
    dimension,
    measurementsAmount: dimensionMeasurementsValues.length,
    mean,
    stddev,
  };
};

/**
 * Computes the cumulative per-dimension statistics after one snapshot.
 *
 * Measurements of all `assessments` are pooled, grouped by dimension, and folded
 * (via `calculateDimensionStats`) into the stats of earlier snapshots.
 *
 * @param assessments - the snapshot's assessments.
 * @param previousSnapshotsDimensionsStats - stats of earlier snapshots, expected
 *   oldest → newest (the running mean/stddev resume from the last entry).
 * @returns one entry per dimension listed in `statisticalMeasurementDimensions`.
 */
export const aggregateSnapshotMeasurementsStats = (
  assessments: Assessment[],
  previousSnapshotsDimensionsStats: DimensionsStats[]
): DimensionsStats => {
  const snapshotMeasurements = assessments.flatMap((assessment) => assessment.measurements);

  return statisticalMeasurementDimensions.reduce((measurementsStats, dimension) => {
    const dimensionPreviousStats = previousSnapshotsDimensionsStats.map(
      (measurementStats) => measurementStats[dimension]
    );

    // Values are fed through Math.log10 downstream, so only finite, strictly
    // positive numbers are usable. Besides null/undefined and 0 (-> -Infinity),
    // NaN, ±Infinity and negative values (-> NaN) must be dropped too: the stats
    // are cumulative, so a single such value would corrupt mean/stddev for this
    // and every later snapshot.
    const dimensionMeasurementsValues = snapshotMeasurements
      .filter((measurement) => measurement.dimension === dimension)
      .map((measurement) => measurement.value)
      .filter(
        (value): value is number =>
          typeof value === 'number' && Number.isFinite(value) && value > 0
      );

    measurementsStats[dimension] = calculateDimensionStats(
      dimension,
      dimensionMeasurementsValues,
      dimensionPreviousStats
    );
    return measurementsStats;
  }, {} as DimensionsStats);
};

/**
 * Derives concrete cutoffs for each configured risk category from one
 * snapshot's dimension stats.
 *
 * - `low-score`: `lowCutoff` from `score-achieved` stats (the `high` threshold is not used).
 * - `low-completion`: `lowCutoff` and `highCutoff` from `completion-achieved`
 *   stats; `highCutoff` is never below `lowCutoff`.
 * - `no-sign-in`: thresholds are copied through unchanged (expressed in days).
 *
 * Categories whose thresholds are missing, or have neither `low` nor `high` set,
 * are omitted, as are categories not handled by the switch below.
 */
export const createSnapshotRisksCutoffs = (
  dimensionsStats: DimensionsStats,
  risksThresholds: RisksThresholds
): RisksCutoffs => {
  /**
   * Converts a z-score into a cutoff in the original (non-log) scale:
   * 10^(mean − z·stddev) for 'low', 10^(mean + z·stddev) for 'high', where
   * mean/stddev are the log10-space stats. A missing z-score yields 0.
   * Example of cutoff calculation: https://cutt.ly/3Q5vyp6
   */
  const createZScoreCutoff = (
    dimension: MeasurementDimensionKeys,
    zScore: number | undefined,
    type: 'low' | 'high'
  ): number => {
    if (zScore === undefined) {
      return 0;
    }

    // NOTE(review): when mean and stddev are both 0 (e.g. no stats for the
    // dimension yet) the cutoff is 10^0 = 1 — confirm that is intended.
    const { mean, stddev } = dimensionsStats[dimension] || { mean: 0, stddev: 0 };
    const offset = zScore * stddev;
    const exponent = type === 'high' ? mean + offset : mean - offset;
    return 10 ** exponent;
  };

  const buildLowScoreCutoffs = ({ low }: RiskThresholds): RiskCutoffs => ({
    lowCutoff: Math.max(0, createZScoreCutoff('score-achieved', low, 'low')),
  });

  const buildLowCompletionCutoffs = ({ low, high }: RiskThresholds): RiskCutoffs => {
    const lowCutoff = Math.max(0, createZScoreCutoff('completion-achieved', low, 'low'));
    const rawHighCutoff = createZScoreCutoff('completion-achieved', high, 'high');
    return { lowCutoff, highCutoff: Math.max(lowCutoff, rawHighCutoff) };
  };

  // Days without sign-in need no statistics; the thresholds are the cutoffs.
  const buildNoSignInCutoffs = ({ low, high }: RiskThresholds): RiskCutoffs => ({
    lowCutoff: low,
    highCutoff: high,
  });

  const cutoffs = {} as RisksCutoffs;
  for (const [riskCategory, thresholds] of Object.entries(risksThresholds)) {
    if (!thresholds) {
      continue;
    }
    if (thresholds.low === undefined && thresholds.high === undefined) {
      continue;
    }

    switch (riskCategory) {
      case RiskCategory['low-score']:
        cutoffs[riskCategory] = buildLowScoreCutoffs(thresholds);
        break;
      case RiskCategory['low-completion']:
        cutoffs[riskCategory] = buildLowCompletionCutoffs(thresholds);
        break;
      case RiskCategory['no-sign-in']:
        cutoffs[riskCategory] = buildNoSignInCutoffs(thresholds);
        break;
    }
  }
  return cutoffs;
};

/**
 * Builds stats-index entries (cumulative dimension stats and risk cutoffs) for
 * `snapshots`, processing them in array order.
 *
 * Each snapshot's dimension stats continue from the stats already stored in
 * `snapshotsStatsIndex`, followed by every snapshot processed earlier in this
 * call. The returned index contains only entries for `snapshots`; the incoming
 * index is read but neither modified nor merged into the result.
 *
 * @param snapshots - snapshots to process; presumably ordered oldest → newest — TODO confirm with callers.
 * @param snapshotsStatsIndex - previously computed entries to continue from.
 * @param risksThresholds - thresholds used to derive each snapshot's cutoffs.
 * @returns a new index keyed by snapshot id.
 */
export const createSnapshotsStatsIndex = ({
  snapshots,
  snapshotsStatsIndex,
  risksThresholds,
}: {
  snapshots: RisksSnapshot[];
  snapshotsStatsIndex: SnapshotsStatsIndex;
  risksThresholds: RisksThresholds;
}): SnapshotsStatsIndex => {
  // Oldest → newest history. calculateDimensionStats resumes the running
  // mean/stddev from the LAST entry, so freshly computed stats must be appended
  // after the stored ones. If they were placed before them, every snapshot after
  // the first in a batch would resume from stale stored stats while still
  // counting the batch's measurements.
  // NOTE(review): Object.values follows insertion order for string keys but sorts
  // integer-like keys ascending — confirm that matches the chronological order.
  const statsHistory: DimensionsStats[] = Object.values(snapshotsStatsIndex).map(
    (snapshotStatsIndex) => snapshotStatsIndex.dimensionsStats
  );

  return snapshots.reduce((statsIndex, snapshot) => {
    const {
      id: snapshotId,
      assessments: { assessments },
    } = snapshot;

    const dimensionsStats = aggregateSnapshotMeasurementsStats(assessments, statsHistory);
    const risksCutoffs = createSnapshotRisksCutoffs(dimensionsStats, risksThresholds);
    // Make this snapshot the newest history entry for the next iteration.
    statsHistory.push(dimensionsStats);

    statsIndex[snapshotId] = {
      snapshotId,
      assessmentsAmount: assessments.length,
      dimensionsStats,
      risksCutoffs,
    };
    return statsIndex;
  }, {} as SnapshotsStatsIndex);
};
