import { erf } from "mathjs";
const math = require('mathjs');
const ss = require('simple-statistics');

/**
 * Builds the ranked list of factors for the ROI view of one outcome.
 *
 * A regression model is fitted over `data` (via `multiple_regression2` when
 * `action` is truthy, otherwise `multiple_regression`). Every factor in the
 * model whose permutation importance is greater than zero is turned into one
 * row, and the rows are sorted by `value`, largest first.
 *
 * @param {*} outcome_id - Outcome question id, forwarded to the regression.
 * @param {*} outcome_name - Not used by this function.
 * @param {object} core_data - Source of `questions` (forwarded to the
 *   regression) and of `standards.response.pillars` (benchmark averages).
 * @param {Array} data - Response sets, forwarded to the regression.
 * @param {boolean} [action=false] - Selects `multiple_regression2`.
 * @returns {Array<object>|object} Sorted rows, or an empty object `{}` when
 *   `data` or `core_data` is falsy (kept as an object for backward
 *   compatibility with existing callers).
 */
export function get_ROI_params(
  outcome_id,
  outcome_name,
  core_data,
  data,
  action = false
) {
  if (!data || !core_data) {
    return {};
  }

  // Benchmark used when no matching pillar (or a falsy average) is found.
  const DEFAULT_BENCHMARK = 7;

  const regression = action ? multiple_regression2 : multiple_regression;
  const { model } = regression(data, outcome_id, core_data?.questions);

  // `pillars` may be absent even when `standards.response` exists; the
  // original chain called `.find` on it unguarded and threw in that case.
  const pillars = core_data?.standards?.response?.pillars;

  // Scales the regression coefficient by the head-room left on a 0-100
  // scale once the distance between the score and the benchmark is removed.
  const adjustCoefficient = (score100, benchmark100, coefficient) =>
    coefficient * (100 - (score100 - benchmark100));

  const rows = [];
  for (const [name, entry] of Object.entries(model)) {
    // Loose equality is intentional: the pillar ids may arrive as strings
    // while the model's index values are numbers.
    const pillarAverage = pillars?.find(
      (f) => f.factor == entry.index.factor && f.id == entry.index.dimension
    )?.average;
    const benchmark = pillarAverage ? pillarAverage : DEFAULT_BENCHMARK;

    if (entry.importance > 0) {
      // Half-width of the 95% band (1.96 * SE), scaled like the coefficient.
      const margin = entry.SE * 1.96 * 0.1 * entry.coefficient;
      rows.push({
        name,
        value: adjustCoefficient(
          entry.average_score * 10,
          benchmark * 10,
          entry.coefficient
        ),
        score: entry.average_score,
        factor: entry.index.factor,
        dimension: entry.index.dimension,
        benchmark,
        coefficient: entry.coefficient,
        impact: entry.coefficient * 0.1 + margin,
        lower_impact: entry.coefficient * 0.1 - margin,
      });
    }
  }

  rows.sort((a, b) => b.value - a.value);

  return rows;
}

/**
 * Averages question responses per (factor, id) pair.
 *
 * Items flagged with `reverse` contribute `10 - response + 1` instead of the
 * raw response (the original author's note assumes a 1-10 scale).
 *
 * @param {Array<{factor: *, id: *, response: number, reverse?: boolean}>} data
 * @returns {Array<{factor: number, dimension: number, average: number}>}
 *   One entry per distinct `factor`/`id` pair, in first-seen order; the
 *   item's `id` is reported as `dimension`.
 */
function calculateAverage(data) {
  // Running totals keyed by "<factor>-<id>".
  const totals = data.reduce((acc, item) => {
    const bucketKey = `${item.factor}-${item.id}`;
    const bucket = acc.get(bucketKey) ?? { sum: 0, count: 0 };

    bucket.sum += item.reverse ? 10 - item.response + 1 : item.response;
    bucket.count++;

    acc.set(bucketKey, bucket);
    return acc;
  }, new Map());

  return Array.from(totals, ([bucketKey, { sum, count }]) => {
    // The key is parsed back into its two numeric parts.
    const [factor, dimension] = bucketKey.split("-").map(Number);
    return { factor, dimension, average: sum / count };
  });
}

/**
 * Mean of the `response` field across a list of outcome entries.
 *
 * @param {Array<{response: number}>} arr
 * @returns {number} The mean, or 0 for an empty list.
 */
function calculateOutcomeAverage(arr) {
  const count = arr.length;
  if (!count) {
    return 0;
  }

  let total = 0;
  arr.forEach((entry) => {
    total = total + entry.response;
  });

  return total / count;
}

/**
 * For one response set, returns the average of the outcome responses whose
 * `q` equals the requested question, together with the per-factor averages
 * of the set's `questions`.
 *
 * @param {object} data - Response set; reads `employee_outcomes.responses`
 *   and `questions`.
 * @param {*} q - Outcome question id, compared with strict equality.
 * @returns {{outcomes: ?number, factor_scores: ?Array}} Both fields are
 *   `null` when the set has no response for `q`.
 */
const get_outcome_and_factor_averages = (data, q) => {
  const isRequestedOutcome = (f) => f.q === q;
  const responses = data?.employee_outcomes?.responses;

  if (!responses?.find(isRequestedOutcome)) {
    return { outcomes: null, factor_scores: null };
  }

  const outcomes = calculateOutcomeAverage(
    responses.filter(isRequestedOutcome)
  );
  const factor_scores = calculateAverage(data.questions);
  return { outcomes, factor_scores };
};

/**
 * Fits a least-squares model (no intercept column) that explains the average
 * outcome score of each response set from its per-factor average scores.
 *
 * @param {Array} data - Either a list of response sets, or a list whose first
 *   element is a non-empty list of response sets (that inner list is used).
 * @param {*} [qId=1] - Outcome question id passed to
 *   `get_outcome_and_factor_averages`.
 * @param {object} questions - Provides
 *   `dimensions[dimension].factors[factor].title` for labelling each factor.
 * @returns {{correlation: number, model: Object<string, object>}}
 *   `model` is keyed by factor title; each entry holds
 *   - `coefficient`: absolute value of the coefficient fitted on standardized features,
 *   - `importance`: increase in mean squared error after shuffling that
 *     feature column (random, so it varies between calls),
 *   - `index`: `{ title, dimension, factor }`,
 *   - `average_score`: mean of the raw feature column,
 *   - `SE`: standard error of the raw (unstandardized) coefficient.
 *   `correlation` is `sqrt(R²)` and is `NaN` when R² is negative.
 *   Returns `{ correlation: 0, model: {} }` when no set has a usable outcome
 *   or when there are not more observations than features.
 */
const multiple_regression = (data, qId = 1, questions) => {
  const sets = data[0] && data[0].length > 0 ? data[0] : data;

  const results_arr = [];
  sets.forEach((set) => {
    const averages = get_outcome_and_factor_averages(set, qId);
    // Sets without a (truthy) outcome average are skipped.
    if (averages.outcomes) {
      results_arr.push(averages);
    }
  });

  if (results_arr.length === 0) {
    return { correlation: 0, model: {} };
  }

  const features = results_arr.map((i) =>
    i.factor_scores.map((f) => f.average)
  );
  const labels = results_arr.map((i) => i.outcomes);

  const n = labels.length; // Number of observations
  // Number of fitted coefficients. The model has no intercept column, so
  // this is just the feature count (same convention as multiple_regression2).
  const p = features[0].length;

  // With n <= p the normal-equation matrix is singular (or the fit is exact
  // and the residual variance is undefined), so bail out instead of letting
  // the matrix inversion throw. Mirrors the guard in multiple_regression2.
  if (p === 0 || n <= p) {
    console.error(
      "Insufficient data: Number of observations must be greater than the number of features."
    );
    return { correlation: 0, model: {} };
  }

  const feature_labels = results_arr[0].factor_scores.map((f) => {
    return {
      title: questions.dimensions[f.dimension].factors[f.factor].title,
      dimension: f.dimension,
      factor: f.factor,
    };
  });

  const column = (rows, colIndex) => rows.map((row) => row[colIndex]);

  // (XᵀX)⁻¹, shared by the coefficient solve and the covariance matrix.
  const invertGram = (X) => {
    const XMatrix = math.matrix(X);
    return math.inv(math.multiply(math.transpose(XMatrix), XMatrix));
  };

  // Solves the normal equations b = (XᵀX)⁻¹ Xᵀy.
  const calculateCoefficients = (X, y) => {
    const XTy = math.multiply(math.transpose(math.matrix(X)), math.matrix(y));
    return math.multiply(invertGram(X), XTy).toArray();
  };

  // Rescales every column to (value - mean) / standard deviation.
  const standardize = (rows) => {
    const means = rows[0].map((_, colIndex) => ss.mean(column(rows, colIndex)));
    const stdDevs = rows[0].map((_, colIndex) =>
      ss.standardDeviation(column(rows, colIndex))
    );
    return rows.map((row) =>
      row.map(
        (value, colIndex) => (value - means[colIndex]) / stdDevs[colIndex]
      )
    );
  };

  // Returns a copy of `rows` with one column shuffled (Fisher-Yates).
  const shuffleFeature = (rows, featureIndex) => {
    const shuffled = rows.map((row) => [...row]);
    for (let i = shuffled.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [shuffled[i][featureIndex], shuffled[j][featureIndex]] = [
        shuffled[j][featureIndex],
        shuffled[i][featureIndex],
      ];
    }
    return shuffled;
  };

  const predict = (rows, weights) =>
    rows.map((row) =>
      row.reduce((sum, value, index) => sum + value * weights[index], 0)
    );

  const sumSquaredError = (actual, predicted) =>
    actual.reduce(
      (sum, val, idx) => sum + Math.pow(val - predicted[idx], 2),
      0
    );

  const meanSquaredError = (actual, predicted) =>
    sumSquaredError(actual, predicted) / actual.length;

  const rSquared = (actual, predicted) => {
    const mean = ss.mean(actual);
    const totalSumOfSquares = actual.reduce(
      (sum, val) => sum + Math.pow(val - mean, 2),
      0
    );
    return 1 - sumSquaredError(actual, predicted) / totalSumOfSquares;
  };

  // Raw coefficients drive predictions; standardized ones are reported.
  const coefficients = calculateCoefficients(features, labels);
  const standardizedCoefficients = calculateCoefficients(
    standardize(features),
    labels
  );

  const baselinePredictions = predict(features, coefficients);
  const baselineError = meanSquaredError(labels, baselinePredictions);

  // Standard errors: sqrt(diag(σ² (XᵀX)⁻¹)) with σ² = RSS / (n - p).
  const residuals = labels.map(
    (actual, idx) => actual - baselinePredictions[idx]
  );
  const RSS = residuals.reduce((sum, res) => sum + res * res, 0);
  const residualVariance = RSS / (n - p);
  const varianceCovarianceMatrix = math.multiply(
    invertGram(features),
    residualVariance
  );
  // Elementwise square root on a plain array; avoids relying on math.sqrt
  // accepting arrays, which not every mathjs version does.
  const standardErrors = math
    .diag(varianceCovarianceMatrix)
    .toArray()
    .map((variance) => Math.sqrt(variance));

  const model = {};
  features[0].forEach((_, featureIndex) => {
    // Permutation importance: how much worse the fit gets once this
    // feature's values are scrambled across observations.
    const shuffledPredictions = predict(
      shuffleFeature(features, featureIndex),
      coefficients
    );
    const importance =
      meanSquaredError(labels, shuffledPredictions) - baselineError;

    model[feature_labels[featureIndex].title] = {
      coefficient: Math.abs(standardizedCoefficients[featureIndex]),
      importance,
      index: feature_labels[featureIndex],
      average_score: ss.mean(column(features, featureIndex)),
      SE: standardErrors[featureIndex],
    };
  });

  const r2 = rSquared(labels, baselinePredictions);

  return { correlation: Math.sqrt(r2), model: model };
};

/**
 * Same regression as `multiple_regression`, fitted twice: once on all
 * observations and once after dropping observations whose standardized
 * residual exceeds 2 in absolute value.
 *
 * @param {Array} data - Either a list of response sets, or a list whose first
 *   element is a non-empty list of response sets (that inner list is used).
 * @param {*} [qId=1] - Outcome question id passed to
 *   `get_outcome_and_factor_averages`.
 * @param {object} questions - Provides
 *   `dimensions[dimension].factors[factor].title` for labelling each factor.
 * @returns {{
 *   correlation: number,
 *   model: Object<string, object>,
 *   outlierInfo: object,
 *   model_without_outliers: Object<string, object>,
 *   correlation_without_outliers: ?number
 * }}
 *   Each model is keyed by factor title with `coefficient` (absolute
 *   standardized coefficient), `importance` (permutation importance, random
 *   between calls), `index`, `average_score` and `SE`. `outlierInfo` holds
 *   `numberOfOutliers` and `indices`. When there is too little data for the
 *   full model every field takes its empty default; when there is too little
 *   data left after outlier removal only the `*_without_outliers` fields do.
 */
export const multiple_regression2 = (data, qId = 1, questions) => {
  const OUTLIER_THRESHOLD = 2;

  const emptyResult = () => ({
    correlation: 0,
    model: {},
    outlierInfo: {},
    model_without_outliers: {},
    correlation_without_outliers: null,
  });

  // Process the data to get averages
  const sets = data[0] && data[0].length > 0 ? data[0] : data;
  const results_arr = [];
  sets.forEach((set) => {
    const averages = get_outcome_and_factor_averages(set, qId);
    // Sets without a (truthy) outcome average are skipped.
    if (averages.outcomes) {
      results_arr.push(averages);
    }
  });

  if (results_arr.length === 0) {
    return emptyResult();
  }

  // Extract features and labels
  const features = results_arr.map((i) =>
    i.factor_scores.map((f) => f.average)
  );
  const labels = results_arr.map((i) => i.outcomes);
  const featureCount = features[0].length;

  // The normal equations need more observations than features.
  if (featureCount === 0 || features.length <= featureCount) {
    console.error(
      "Insufficient data: Number of observations must be greater than the number of features."
    );
    return emptyResult();
  }

  const feature_labels = results_arr[0].factor_scores.map((f) => {
    return {
      title: questions.dimensions[f.dimension].factors[f.factor].title,
      dimension: f.dimension,
      factor: f.factor,
    };
  });

  const column = (rows, colIndex) => rows.map((row) => row[colIndex]);

  // (XᵀX)⁻¹, shared by the coefficient solve and the covariance matrix.
  const invertGram = (X) => {
    const XMatrix = math.matrix(X);
    return math.inv(math.multiply(math.transpose(XMatrix), XMatrix));
  };

  // Solves the normal equations b = (XᵀX)⁻¹ Xᵀy.
  const calculateCoefficients = (X, y) => {
    const XTy = math.multiply(math.transpose(math.matrix(X)), math.matrix(y));
    return math.multiply(invertGram(X), XTy).toArray();
  };

  // Rescales every column to (value - mean) / standard deviation.
  const standardize = (rows) => {
    const means = rows[0].map((_, colIndex) => ss.mean(column(rows, colIndex)));
    const stdDevs = rows[0].map((_, colIndex) =>
      ss.standardDeviation(column(rows, colIndex))
    );
    return rows.map((row) =>
      row.map(
        (value, colIndex) => (value - means[colIndex]) / stdDevs[colIndex]
      )
    );
  };

  const predict = (rows, weights) =>
    rows.map((row) =>
      row.reduce((sum, value, index) => sum + value * weights[index], 0)
    );

  const sumSquaredError = (actual, predicted) =>
    actual.reduce(
      (sum, val, idx) => sum + Math.pow(val - predicted[idx], 2),
      0
    );

  const meanSquaredError = (actual, predicted) =>
    sumSquaredError(actual, predicted) / actual.length;

  const rSquared = (actual, predicted) => {
    const mean = ss.mean(actual);
    const totalSumOfSquares = actual.reduce(
      (sum, val) => sum + Math.pow(val - mean, 2),
      0
    );
    return 1 - sumSquaredError(actual, predicted) / totalSumOfSquares;
  };

  // Standard errors: sqrt(diag(σ² (XᵀX)⁻¹)).
  const calculateStandardErrors = (X, residualVariance) =>
    math
      .diag(math.multiply(invertGram(X), residualVariance))
      .toArray()
      // Elementwise square root on a plain array; avoids relying on
      // math.sqrt accepting arrays, which not every mathjs version does.
      .map((variance) => Math.sqrt(variance));

  // Returns a copy of `rows` with one column shuffled (Fisher-Yates).
  const shuffleFeature = (rows, featureIndex) => {
    const shuffled = rows.map((row) => [...row]);
    for (let i = shuffled.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [shuffled[i][featureIndex], shuffled[j][featureIndex]] = [
        shuffled[j][featureIndex],
        shuffled[i][featureIndex],
      ];
    }
    return shuffled;
  };

  // Runs the whole pipeline (coefficients, standard errors, permutation
  // importance, R²) on one feature/label set.
  const fitModel = (X, y) => {
    // Raw coefficients drive predictions; standardized ones are reported.
    const coefficients = calculateCoefficients(X, y);
    const standardizedCoefficients = calculateCoefficients(standardize(X), y);

    const predictions = predict(X, coefficients);
    const residuals = y.map((actual, idx) => actual - predictions[idx]);
    const RSS = residuals.reduce((sum, res) => sum + res * res, 0);
    const residualVariance = RSS / (y.length - X[0].length);
    const standardErrors = calculateStandardErrors(X, residualVariance);
    const baselineError = meanSquaredError(y, predictions);

    const fitted = {};
    X[0].forEach((_, featureIndex) => {
      // Permutation importance: how much worse the fit gets once this
      // feature's values are scrambled across observations.
      const shuffledPredictions = predict(
        shuffleFeature(X, featureIndex),
        coefficients
      );
      const importance =
        meanSquaredError(y, shuffledPredictions) - baselineError;

      fitted[feature_labels[featureIndex].title] = {
        coefficient: Math.abs(standardizedCoefficients[featureIndex]),
        importance,
        index: feature_labels[featureIndex],
        average_score: ss.mean(column(X, featureIndex)),
        SE: standardErrors[featureIndex],
      };
    });

    return { model: fitted, residuals, r2: rSquared(y, predictions) };
  };

  // === Model with All Data ===
  const full = fitModel(features, labels);

  // === Identify Outliers ===
  const residualMean = ss.mean(full.residuals);
  const residualStdDev = ss.standardDeviation(full.residuals);
  const outlierIndices = [];
  full.residuals.forEach((res, index) => {
    const standardized = (res - residualMean) / residualStdDev;
    if (Math.abs(standardized) > OUTLIER_THRESHOLD) {
      outlierIndices.push(index);
    }
  });
  const outlierInfo = {
    numberOfOutliers: outlierIndices.length,
    indices: outlierIndices,
  };

  // === Model without Outliers ===
  const outlierSet = new Set(outlierIndices);
  const keep = (_, idx) => !outlierSet.has(idx);
  const filteredFeatures = features.filter(keep);
  const filteredLabels = labels.filter(keep);

  let model_without_outliers = {};
  let correlation_without_outliers = null;

  // Same requirement as for the full model: the refit needs more
  // observations than features, otherwise the inversion fails.
  if (filteredLabels.length <= featureCount) {
    console.warn("Not enough data to build a model without outliers.");
  } else {
    const trimmed = fitModel(filteredFeatures, filteredLabels);
    model_without_outliers = trimmed.model;
    correlation_without_outliers = Math.sqrt(trimmed.r2);
  }

  // Return both models and outlier information
  return {
    correlation: Math.sqrt(full.r2),
    model: full.model,
    outlierInfo: outlierInfo,
    model_without_outliers: model_without_outliers,
    correlation_without_outliers: correlation_without_outliers,
  };
};