import * as tf from "@tensorflow/tfjs";
import cv, { bool } from "@techstark/opencv-js";
import { Tensor4D } from "@tensorflow/tfjs";
import {
  DetectedObject,
  ImageObject,
  RecognitionModules,
  FilterResponse,
} from "types/imageRecognition";
// import '@tensorflow/tfjs-backend-webgpu';

// Minimum height/width ratio the cassette bounding box must reach for
// `localization` to accept it (i.e. the box must be clearly taller than wide).
const ASPECT_THRESH = 1.6;
// Minimum fraction of the whole image area the cassette bounding box must
// cover for `localization` to accept it.
const AREA_THRESH = 0.05;

/**
 * Runs a single inference on an all-ones dummy input and then frees both the
 * dummy input and every output tensor. Presumably this is done so that
 * first-use costs (e.g. kernel/shader compilation) are paid at load time
 * rather than on the first real image.
 *
 * Dynamic input dimensions (reported as -1 or missing) are replaced by 1 so a
 * concrete dummy tensor can be built.
 *
 * @param model Model to warm up.
 * @param useExecuteAsync Use `executeAsync` (required by models with dynamic
 *   control flow) instead of the synchronous `execute`.
 */
async function warmUpModel(
  model: tf.GraphModel,
  useExecuteAsync: boolean
): Promise<void> {
  const shape = (model.inputs[0].shape as number[]).map((dim) =>
    dim == null || dim < 0 ? 1 : dim
  );
  const dummy = tf.ones(shape);
  try {
    const outputs = useExecuteAsync
      ? await model.executeAsync(dummy)
      : model.execute(dummy);
    tf.dispose(outputs);
  } finally {
    dummy.dispose();
  }
}

/**
 * Loads the detection and classification graph models and warms each one up.
 *
 * The returned promise resolves only after both warm-up passes have finished,
 * so the models are ready for immediate use. Any load or warm-up failure
 * rejects the promise instead of surfacing later as an unhandled rejection.
 *
 * @param detectionModelPath URL/path of the detection model (run via `executeAsync`).
 * @param classificationModelPath URL/path of the classification model (run via `execute`).
 * @returns Both loaded models.
 */
export async function loadModel(
  detectionModelPath: string,
  classificationModelPath: string
): Promise<RecognitionModules> {
  // await tf.setBackend('webgpu')
  // await tf.ready()
  tf.enableProdMode();

  const detectionModel: tf.GraphModel = await tf.loadGraphModel(
    detectionModelPath
  );
  await warmUpModel(detectionModel, true);

  const classificationModel: tf.GraphModel = await tf.loadGraphModel(
    classificationModelPath
  );
  await warmUpModel(classificationModel, false);

  const recognitionModules: RecognitionModules = {
    detectionModel,
    classificationModel,
  };
  return recognitionModules;
}

/**
 * Logs the current TensorFlow.js memory statistics (`tf.memory()`), tagged
 * with `title` so successive snapshots can be told apart when hunting leaks.
 */
export const printTS_memory = (title: string) => {
  const tag = { title };
  const stats = tf.memory();
  console.log("printTS_memory: ", tag, stats);
};

/**
 * Disposes the detection model first, then the classification model,
 * releasing the weights each one holds. Neither model should be used for
 * inference after this call.
 */
export const deleteModels = (models: RecognitionModules): void => {
  console.log("deleteModels");
  const { detectionModel, classificationModel } = models;
  for (const model of [detectionModel, classificationModel]) {
    model.dispose();
  }
};

/**
 * Tells whether box `bbox2` lies completely within box `bbox1`; shared edges
 * count as inside. Both boxes are `[xmin, ymin, xmax, ymax]`.
 */
function isBboxInside(bbox1: number[], bbox2: number[]): boolean {
  const [outerLeft, outerTop, outerRight, outerBottom] = bbox1;
  const [innerLeft, innerTop, innerRight, innerBottom] = bbox2;
  const fitsHorizontally = innerLeft >= outerLeft && innerRight <= outerRight;
  const fitsVertically = innerTop >= outerTop && innerBottom <= outerBottom;
  return fitsHorizontally && fitsVertically;
}

/**
 * Converts a normalized `[xmin, ymin, xmax, ymax]` box into pixel coordinates
 * of a `width` x `height` image.
 */
function toPixelBox(
  box: ArrayLike<number>,
  width: number,
  height: number
): number[] {
  return [box[0] * width, box[1] * height, box[2] * width, box[3] * height];
}

/**
 * Crops a normalized `[xmin, ymin, xmax, ymax]` box out of `image` and resizes
 * it to `size` x `size`. Intermediate tensors are released by `tf.tidy`; the
 * returned crop (batch axis squeezed away) is owned by the caller.
 */
function cropNormalizedBox(
  image: tf.Tensor3D,
  box: ArrayLike<number>,
  size = 320
): tf.Tensor {
  return tf.tidy(() => {
    const batched = image.expandDims(0) as Tensor4D;
    // cropAndResize expects boxes as [y1, x1, y2, x2].
    const crop = tf.image.cropAndResize(
      batched,
      [[box[1], box[0], box[3], box[2]]],
      [0],
      [size, size]
    );
    return crop.squeeze();
  });
}

/**
 * Runs the detection model on `img` and returns the cassette and, when found,
 * the strip inside it.
 *
 * - Class 0 (cassette): the highest-scoring box is kept only if its
 *   height/width ratio is at least `ASPECT_THRESH` and it covers at least
 *   `AREA_THRESH` of the image area.
 * - Class 1 (strip): the highest-scoring box is kept only when a cassette was
 *   accepted and the strip box lies entirely inside the cassette box.
 *
 * Each returned object carries its pixel-space bbox, class id, score,
 * height/width ratio, a 320x320 crop tensor and `img.index`.
 *
 * NOTE(review): as in the original, the cassette's crop is disposed before
 * returning, so only the strip's crop (index 1) is usable; the caller owns it
 * and must dispose it (`classification` does so).
 *
 * All intermediate tensors are released even if inference throws.
 *
 * @returns `[]`, `[cassette]` or `[cassette, strip]`.
 */
export async function localization(
  model: tf.GraphModel,
  img: ImageObject
): Promise<DetectedObject[]> {
  const width: number = img.data.width;
  const height: number = img.data.height;
  const image: tf.Tensor3D = tf.browser.fromPixels(img.data);
  let predictions: tf.Tensor[] = [];
  const localizedObjects: DetectedObject[] = [];

  try {
    // Detector input: resized to 640x640, scaled to [0, 1], batch axis added.
    const input = tf.tidy(() =>
      tf.image.resizeBilinear(image, [640, 640]).div(255.0).expandDims(0)
    );
    try {
      predictions = (await model.executeAsync(input)) as tf.Tensor[];
    } finally {
      input.dispose();
    }

    const [boxes, scores, classes, validDetections] = predictions as [
      tf.Tensor,
      tf.Tensor,
      tf.Tensor,
      tf.Tensor
    ];
    const boxesData = boxes.dataSync() as Float32Array;
    const scoresData = scores.dataSync() as Float32Array;
    const classesData = classes.dataSync() as Int32Array;
    // Only the first `validDetections` entries are treated as real detections;
    // the remainder is assumed to be NMS padding.
    const validCount = Math.min(
      classesData.length,
      validDetections.dataSync()[0]
    );

    // Highest-scoring detection per class: 0 = cassette, 1 = strip (-1 = none).
    let cassetteIdx = -1;
    let cassetteScore = -1;
    let stripIdx = -1;
    let stripScore = -1;
    for (let i = 0; i < validCount; i++) {
      const score = scoresData[i];
      if (classesData[i] === 0 && score > cassetteScore) {
        cassetteScore = score;
        cassetteIdx = i;
      } else if (classesData[i] === 1 && score > stripScore) {
        stripScore = score;
        stripIdx = i;
      }
    }

    if (cassetteIdx >= 0) {
      const box = boxesData.slice(4 * cassetteIdx, 4 * cassetteIdx + 4);
      const bbox = toPixelBox(box, width, height);
      const boxWidth = bbox[2] - bbox[0];
      const boxHeight = bbox[3] - bbox[1];
      const areaRatio = (boxHeight * boxWidth) / (width * height);
      const aspectRatio = boxHeight / boxWidth;
      if (aspectRatio >= ASPECT_THRESH && areaRatio >= AREA_THRESH) {
        localizedObjects.push({
          bbox,
          class: classesData[cassetteIdx],
          conf: scoresData[cassetteIdx],
          aspectRatio,
          crop: cropNormalizedBox(image, box),
          index: img.index,
        });
      }
    }

    if (stripIdx >= 0 && localizedObjects.length > 0) {
      const box = boxesData.slice(4 * stripIdx, 4 * stripIdx + 4);
      const bbox = toPixelBox(box, width, height);
      // Only crop the strip when it will actually be returned.
      if (isBboxInside(localizedObjects[0].bbox, bbox)) {
        localizedObjects.push({
          bbox,
          class: classesData[stripIdx],
          conf: scoresData[stripIdx],
          aspectRatio: (bbox[3] - bbox[1]) / (bbox[2] - bbox[0]),
          crop: cropNormalizedBox(image, box),
          index: img.index,
        });
      }
    }

    // The cassette crop is not handed to callers (see NOTE above).
    if (localizedObjects.length > 0) localizedObjects[0].crop.dispose();
    return localizedObjects;
  } catch (error) {
    // Don't leak crops that were already created before the failure.
    for (const obj of localizedObjects) obj.crop.dispose();
    throw error;
  } finally {
    image.dispose();
    tf.dispose(predictions);
  }
}

/**
 * Classifies every detection crop and tallies a majority vote.
 *
 * Each crop is scaled to [0, 1], given a batch axis and run through `model`.
 * It votes positive when the largest output value sits at index 1. On ties,
 * the first occurrence wins, as with `indexOf`.
 *
 * Side effect: every `detection.crop` tensor is disposed, whether or not
 * classification succeeds, so callers must not reuse the crops afterwards.
 *
 * Crops are processed one at a time so only one crop's intermediate tensors
 * are alive at once.
 *
 * @param model Classification model; assumed to return a single tensor of
 *   per-class scores — TODO confirm.
 * @param detections Objects whose `crop` tensors are classified.
 * @returns `resault` is true when positive votes strictly outnumber negative
 *   ones, together with the raw vote counts. The `resault` spelling is part of
 *   the public return shape and is kept for compatibility.
 */
export async function classification(
  model: tf.GraphModel,
  detections: DetectedObject[]
): Promise<{resault: boolean, trueCount: number, falseCount: number}> {
  let trueCount = 0;
  let falseCount = 0;
  try {
    for (const detection of detections) {
      // tidy frees the scaled/expanded input and any other intermediates.
      const output = tf.tidy(() => {
        const input = detection.crop.div(255).expandDims(0);
        return model.execute(input) as tf.Tensor;
      });
      const scores = output.dataSync() as Float32Array;
      output.dispose();

      const isPositive = scores.indexOf(Math.max(...scores)) === 1;
      if (isPositive) {
        trueCount++;
      } else {
        falseCount++;
      }
    }
  } finally {
    for (const detection of detections) detection.crop.dispose();
  }
  return { resault: trueCount > falseCount, trueCount, falseCount };
}

/**
 * Flags an image as blurred when the standard deviation of its Laplacian is
 * below `threshold`. Few edges yield a narrow spread of Laplacian responses.
 *
 * @param src Image to analyse. Expected to be RGBA (converted with
 *   `COLOR_RGBA2GRAY`). It is neither modified nor freed here.
 * @param threshold Laplacian standard deviation below which the image counts
 *   as blurred (default 3.0).
 * @returns `true` if the image is considered blurred.
 */
async function blurDetection(src: cv.Mat, threshold = 3.0): Promise<boolean> {
  const gray: cv.Mat = new cv.Mat();
  const laplacian: cv.Mat = new cv.Mat();
  const mean: cv.Mat = new cv.Mat();
  const stddev: cv.Mat = new cv.Mat();
  try {
    cv.cvtColor(src, gray, cv.COLOR_RGBA2GRAY);
    // 64-bit float output keeps negative Laplacian responses.
    cv.Laplacian(gray, laplacian, cv.CV_64F);
    cv.meanStdDev(laplacian, mean, stddev);
    return stddev.data64F[0] < threshold;
  } finally {
    // OpenCV.js Mats live in WASM memory and must be freed explicitly.
    gray.delete();
    laplacian.delete();
    mean.delete();
    stddev.delete();
  }
}

/**
 * Flags an image as low-contrast when its central band of grayscale
 * intensities is narrow.
 *
 * A 256-bin grayscale histogram is built. Bins are then consumed from the
 * dark end until at least `percentileThreshold` percent of the pixels have
 * been counted, and likewise from the bright end. If the remaining intensity
 * range spans at most `maxRangeFraction` of the 0..255 scale, the image
 * counts as low contrast.
 *
 * @param src Image to analyse. Expected to be RGBA, as produced by
 *   `cv.imread`. It is neither modified nor freed here.
 * @param percentileThreshold Percentage of pixels trimmed from each end (default 5).
 * @param maxRangeFraction Largest remaining range, as a fraction of 255, that
 *   still counts as low contrast (default 0.15).
 * @returns `true` if the image is considered low contrast.
 */
async function lightingDetection(
  src: cv.Mat,
  percentileThreshold = 5,
  maxRangeFraction = 0.15
): Promise<boolean> {
  const histSize = 256;
  const grayscale: cv.Mat = new cv.Mat();
  const grayVec: cv.MatVector = new cv.MatVector();
  const mask: cv.Mat = new cv.Mat(); // empty mask = use every pixel
  const hist: cv.Mat = new cv.Mat();
  try {
    // cv.imread yields RGBA; use the matching conversion. BGR2GRAY, used
    // previously, swapped the red/blue weights and disagreed with blurDetection.
    cv.cvtColor(src, grayscale, cv.COLOR_RGBA2GRAY);
    grayVec.push_back(grayscale);
    cv.calcHist(grayVec, [0], mask, hist, [histSize], [0, 256], false);
    const counts = hist.data32F;

    const totalPixels = grayscale.cols * grayscale.rows;
    const pixelThreshold = Math.round((percentileThreshold / 100) * totalPixels);

    // Trim the darkest pixels.
    let lowIntensity = 0;
    let highIntensity = histSize - 1;
    let pixelCount = 0;
    while (pixelCount < pixelThreshold && lowIntensity < highIntensity) {
      pixelCount += counts[lowIntensity];
      lowIntensity++;
    }
    // Trim the brightest pixels.
    pixelCount = 0;
    while (pixelCount < pixelThreshold && highIntensity > lowIntensity) {
      pixelCount += counts[highIntensity];
      highIntensity--;
    }

    const intensityRange = highIntensity - lowIntensity;
    return intensityRange / 255 <= maxRangeFraction;
  } finally {
    // OpenCV.js Mats live in WASM memory and must be freed explicitly.
    grayscale.delete();
    grayVec.delete();
    mask.delete();
    hist.delete();
  }
}
/**
 * Assesses image quality (blur and low contrast) on a fixed-size central crop.
 *
 * A 215x450 (width x height) region centred in the image is analysed by
 * `blurDetection` and `lightingDetection`. `isGoodQuality` is true only when
 * the crop is neither blurred nor low contrast. `isClearQuality` and
 * `isOccluded` are never set here and are always `false`.
 *
 * Errors never propagate; they are logged and an all-`false` result is
 * returned. This includes an image too small for the crop. All OpenCV Mats
 * created here are freed on every path.
 *
 * @param img Image wrapper whose `data` is read with `cv.imread`.
 */
export async function filtering(img: ImageObject) {
  const cropWidth = 215;
  const cropHeight = 450;
  const emptyQuality = () => ({
    isGoodQuality: false,
    isClearQuality: false,
    isLowContrast: false,
    isBlurred: false,
    isOccluded: false,
  });

  let src: cv.Mat | undefined;
  let roi: cv.Mat | undefined;
  let cropped: cv.Mat | undefined;
  try {
    const qualityDict = emptyQuality();
    src = cv.imread(img.data);

    // Centre the crop; roi() throws if it does not fit inside the image.
    const rect = new cv.Rect(
      Math.floor(src.cols / 2 - cropWidth / 2),
      Math.floor(src.rows / 2 - cropHeight / 2),
      cropWidth,
      cropHeight
    );
    roi = src.roi(rect);
    cropped = new cv.Mat();
    roi.copyTo(cropped);

    try {
      const [isBlurred, isLowContrast] = await Promise.all([
        blurDetection(cropped),
        lightingDetection(cropped),
      ]);
      qualityDict.isBlurred = isBlurred;
      qualityDict.isLowContrast = isLowContrast;
      qualityDict.isGoodQuality = !(isBlurred || isLowContrast);
    } catch (error) {
      console.error("Filtering error: blurDetection, lightingDetection functions: ", error)
    }
    return qualityDict;
  } catch (error) {
    console.error("ERROR in Filtering: ", error)
    return emptyQuality();
  } finally {
    // The roi view and the copied crop are WASM allocations too, not only src.
    cropped?.delete();
    roi?.delete();
    src?.delete();
  }
}
