import { FaceLandmarker } from '@mediapipe/tasks-vision';
import { Matrix4, Vector3 } from 'three';
import { currentState } from '../CurrentState';
import { AvatarAnimatable } from './avatar_animatable';
import { avatars } from './avatar';

/**
 * Shared face landmarker, created on the first `initLandmarker` call.
 * It is `undefined` until that creation has finished, which is why callers check it before use.
 */
let refLandmarker: FaceLandmarker | undefined;
/** Video element currently used as the tracking source; `disposeVideo` clears it. */
let refVideo: HTMLVideoElement | undefined;

// export const stopTracking = () => {
//   // disposeVideo();
//   refVideo = undefined;
// };
/** Stops feeding video into the tracker and returns the avatar to its un-tracked state. */
export const disposeVideo = (): void => {
  // Clearing the source first turns maybePredictBlendshapes into an early-return no-op.
  refVideo = undefined;
  resetFaceTracking();
};

/**
 * Undoes the effects of face tracking:
 * - copies `head_mat` back into the current avatar's tracked bone (index 5);
 * - re-enables three.js matrix auto-update on that bone;
 * - zeroes every morph-target influence on every downloaded avatar, not only the current one.
 */
export const resetFaceTracking = () => {
  const avatar = currentState.avatar as AvatarAnimatable;
  // Bone 5 is the bone whose matrix face tracking writes directly.
  const trackedBone = avatar.skeleton.bones[5];
  trackedBone.matrix.copy(avatar.head_mat);
  trackedBone.matrixAutoUpdate = true;

  // Reset blendshape coefficients; only the avatar objects are needed, not their keys.
  for (const candidate of Object.values(avatars)) {
    if (!candidate.is_downloaded) {
      continue;
    }
    (candidate as AvatarAnimatable).meshes.forEach((mesh) => {
      mesh.morphTargetInfluences?.fill(0);
    });
  }
};

// Predicted coefficients sometimes need to be mirrored (Left <-> Right blendshapes).
// Maps each blendshape name to its mirrored name; filled on first use by maybePredictBlendshapes.
let mirror_map: Record<string, string>;

/**
 * Builds the name -> mirrored-name lookup for `blendshapeNames`.
 * The first "Right" in a name becomes "Left". Otherwise the first "Left" becomes "Right".
 * A name with neither marker maps to itself.
 */
const buildMirrorMap = (blendshapeNames: string[]) => {
  const swapSide = (name: string): string => {
    if (name.includes('Right')) return name.replace('Right', 'Left');
    if (name.includes('Left')) return name.replace('Left', 'Right');
    return name;
  };

  const mirrored: typeof mirror_map = {};
  for (const name of blendshapeNames) {
    mirrored[name] = swapSide(name);
  }
  return mirrored;
};

/**
 * Gain applied to a MediaPipe blendshape score before it drives the avatar's morph targets.
 * `name` is the MediaPipe category name (before any Left/Right mirroring).
 */
function getMultiplier(name: string) {
  // Every browDown* variant (both sides) is amplified.
  if (name.startsWith('browDown')) {
    return 2.5;
  }
  switch (name) {
    case 'jawLeft':
    case 'jawRight':
      return 3;
    default:
      // Everything else, including cheekPuff* (gain 1), is passed through unscaled.
      return 1.0;
  }
}

/** Last timestamp passed to `detectForVideo`; kept so the sequence strictly increases. */
let lastDetectionTimestampMs = -Infinity;

/**
 * Runs face detection on the current frame of the tracked video and applies the result to the
 * current avatar:
 * - bone 5's rotation is replaced by the (mirrored) facial transformation, keeping its translation;
 * - every mesh's morph targets receive the (mirrored, gain-scaled) blendshape scores.
 *
 * If a tracking recorder is active, it also receives:
 * - the frame time;
 * - the head transform;
 * - one influence per blendshape, using the first mesh's morph index.
 *
 * Does nothing when:
 * - no video or landmarker is ready;
 * - the video has no dimensions yet;
 * - no face is detected.
 */
export const maybePredictBlendshapes = async () => {
  if (!refVideo || !refLandmarker || refVideo.videoHeight === 0 || refVideo.videoWidth === 0) {
    return;
  }

  const avatar = currentState.avatar as AvatarAnimatable;

  // Actual detection. MediaPipe's VIDEO running mode needs strictly increasing timestamps.
  // Date.now() is wall-clock time: it can repeat within a millisecond or step backwards,
  // so the value passed to the landmarker is clamped. The recorder still gets the raw time.
  const nowInMs = Date.now();
  lastDetectionTimestampMs = Math.max(nowInMs, lastDetectionTimestampMs + 1);
  const detections = refLandmarker.detectForVideo(refVideo, lastDetectionTimestampMs, {});

  const mirror = true;
  const faceTransform = detections.facialTransformationMatrixes?.[0];
  if (!faceTransform) {
    return;
  }

  const matrix = new Matrix4().fromArray(faceTransform.data);
  if (mirror) {
    // Reflect across x: S * M * S with S = diag(-1, 1, 1). In column-major elements this negates
    // the rotation terms that mix x with y/z. The translation is replaced below anyway.
    matrix.elements[1] *= -1;
    matrix.elements[2] *= -1;
    matrix.elements[4] *= -1;
    matrix.elements[8] *= -1;
  }

  // Apply bones: keep the bone's current translation and swap in the tracked rotation.
  const headBone = avatar.skeleton.bones[5];
  // This matrix is written by hand, so three.js must not recompose it from position/quaternion.
  // Re-asserted every frame because the current avatar may have changed since initLandmarker.
  headBone.matrixAutoUpdate = false;
  const m = headBone.matrix.elements;
  const pos = new Vector3(m[12], m[13], m[14]);

  headBone.matrix.copy(matrix);
  headBone.matrix.multiply(avatar.head_mat);
  headBone.matrix.setPosition(pos);

  // recording part 1
  currentState.tracking_recorder?.newFrame(nowInMs);
  currentState.tracking_recorder?.recordHeadTransform(headBone.matrix);

  // Apply blendshapes
  const categories = detections.faceBlendshapes?.[0]?.categories ?? [];
  // Mirroring depends only on the name, so the map is built from the model's own category names.
  // That way it covers every name looked up below, whichever avatar is currently active.
  if (mirror && categories.length > 0) {
    mirror_map ??= buildMirrorMap(categories.map((c) => c.categoryName));
  }

  categories.forEach((v) => {
    const multiplier = getMultiplier(v.categoryName);
    const _name: string | undefined = mirror ? mirror_map[v.categoryName] : v.categoryName;
    if (_name === undefined) {
      return;
    }
    const influence = v.score * multiplier;

    avatar.meshes.forEach((mesh, meshIndex) => {
      // Meshes without morph targets, or without this particular target, are skipped.
      const idx = mesh.morphTargetDictionary?.[_name];
      if (idx === undefined) {
        return;
      }

      if (mesh.morphTargetInfluences) {
        mesh.morphTargetInfluences[idx] = influence;
      }

      // recording part 2 (blendshapes): once per category, using the first mesh's morph index.
      if (meshIndex === 0) {
        currentState.tracking_recorder?.recordInfluence(idx, influence);
      }
    });
  });
};

/**
 * Starts face tracking on `video`:
 * - makes `video` the tracking source;
 * - disables matrix auto-update on the current avatar's bone 5;
 * - on the first call only, loads the MediaPipe runtime and model and creates the shared
 *   landmarker.
 *
 * Later calls only switch the video source and reuse the existing landmarker.
 *
 * NOTE(review): the wasm files come from the `@latest` CDN tag, while the JS API comes from the
 * installed npm package. A version mismatch between the two may break loading; consider pinning.
 *
 * NOTE(review): `refLandmarker` is only assigned after the creation awaits. Overlapping first
 * calls can therefore each create a landmarker.
 */
export const initLandmarker = async (video: HTMLVideoElement) => {
  const tasksVision = await import('@mediapipe/tasks-vision');

  refVideo = video;
  // Bone 5's matrix is written directly each frame, so three.js must stop recomputing it.
  const trackedBone = (currentState.avatar as AvatarAnimatable).skeleton.bones[5];
  trackedBone.matrixAutoUpdate = false;

  // The landmarker is built once and then reused.
  if (refLandmarker !== undefined) return;

  const canvas = document.createElement('canvas');
  const wasmFileset = await tasksVision.FilesetResolver.forVisionTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm'
  );

  refLandmarker = await tasksVision.FaceLandmarker.createFromOptions(wasmFileset, {
    canvas,
    outputFaceBlendshapes: true,
    outputFacialTransformationMatrixes: true,
    runningMode: 'VIDEO',
    baseOptions: {
      delegate: 'GPU',
      modelAssetPath: 'https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task',
    },
  });
};
