import { GPUComputationRenderer } from 'three/examples/jsm/misc/GPUComputationRenderer';
import { color_ramp_params } from '../core/resources/assets/hair';

import {
  Color,
  BufferAttribute,
  DataTexture,
  Group,
  MeshStandardMaterial,
  Material,
  RGBAFormat,
  SkinnedMesh,
  sRGBEncoding,
  LinearEncoding,
  Texture,
  UnsignedByteType,
  LinearFilter,
  LinearMipMapLinearFilter,
} from 'three';
import { AvatarView } from '../AvatarView';
import { brightness_contrast_pars } from '../core/Assets/shaders/brightness_contrast';
import { color_ramp_pars } from '../core/Assets/shaders/color_ramp';
import { curve_pars } from '../core/Assets/shaders/curve';
import { hair_cap_pars } from '../core/Assets/shaders/hair_cap';
import { Uniform } from '../core/Assets/shaders/_common';
import { Avatar } from '../core/resources/avatar';
import { EyeShaderParams, eye_pars } from '../core/Assets/shaders/eye_shader';
import { AvatarAnimatable } from '../core/resources/avatar_animatable';

import { Transformer } from './model_export';
import { hairCapShaderParams } from '../core/resources/haircap';

// Module-level Transformer instance exposed to importers of this module.
// NOTE(review): it is declared but never assigned in this file — presumably initialised
// elsewhere before use; confirm, since reading it before that yields undefined.
export let transformer: Transformer;

/**
 * Bilinearly blends four corner values using only the fractional parts of `u` and `v`.
 *
 * With du = u - floor(u) and dv = v - floor(v) the corner weights are:
 * - alpha_1: du * (1 - dv)
 * - alpha_2: (1 - du) * (1 - dv)
 * - alpha_3: (1 - du) * dv
 * - alpha_4: du * dv
 *
 * NOTE(review): no live call site is visible in this file.
 */
function blerp(alpha_1: number, alpha_2: number, alpha_3: number, alpha_4: number, u: number, v: number) {
  const fracU = u - Math.floor(u);
  const fracV = v - Math.floor(v);
  const restU = 1 - fracU;
  const restV = 1 - fracV;

  return (
    alpha_1 * restV * fracU +
    alpha_2 * restV * restU +
    alpha_3 * restU * fracV +
    alpha_4 * fracV * fracU
  );
}

/**
 * Removes the faces of `avatar_mesh` that are masked out by `offset_map`.
 *
 * Each vertex samples (nearest neighbour) the texel under its UV from the CPU-side pixel
 * data of `offset_map`; its alpha is `1 - value / 255` for the selected `channel`. A
 * triangle is kept only when the sum of its three vertex alphas is below 2.01.
 *
 * The geometry's index is replaced in place. The original index attribute is returned so
 * the caller can restore it later (e.g. with set_avatar_indices).
 *
 * @param avatar_mesh Indexed mesh whose faces are filtered.
 * @param offset_map Texture whose `image` holds 4-byte-per-texel, row-major pixel data
 *   (row = uv.y * height, column = uv.x * width).
 * @param channel Byte offset inside a texel to read (0-3); defaults to 1.
 * @returns The index attribute the geometry had before filtering.
 * @throws Error when the geometry has no index.
 */
export function filter_faces_by_alpha(avatar_mesh: SkinnedMesh, offset_map: DataTexture, channel = 1) {
  const { data, width, height } = offset_map.image;
  const geometry = avatar_mesh.geometry;
  const originalIndex = geometry.index;
  if (!originalIndex) {
    throw new Error('filter_faces_by_alpha: geometry must be indexed');
  }

  const uv = geometry.attributes['uv'];
  const vertexCount = geometry.attributes['position'].count;
  const alphas = new Float64Array(vertexCount);

  for (let vertex = 0; vertex < vertexCount; vertex++) {
    // Clamp so UVs on (or beyond) the 1.0 edge hit the last row/column instead of reading
    // past the end of the buffer, which yielded NaN alphas and silently dropped faces.
    const row = Math.min(Math.max(Math.floor(uv.getY(vertex) * height), 0), height - 1);
    const col = Math.min(Math.max(Math.floor(uv.getX(vertex) * width), 0), width - 1);
    // Row-major layout: the row stride is the image width (using the height only worked
    // for square maps).
    alphas[vertex] = 1 - data[4 * (row * width + col) + channel] / 255.0;
  }

  // Aggregation rule: drop a face once its vertices are, in total, nearly fully masked.
  const indexArray = originalIndex.array;
  const kept: number[] = [];
  for (let i = 0; i + 2 < originalIndex.count; i += 3) {
    const a = indexArray[i];
    const b = indexArray[i + 1];
    const c = indexArray[i + 2];
    if (alphas[a] + alphas[b] + alphas[c] < 2.01) {
      kept.push(a, b, c);
    }
  }

  geometry.setIndex(kept);
  // Flag the newly installed index attribute, not the detached original.
  geometry.index!.needsUpdate = true;

  return originalIndex;
}

/**
 * Installs `indices` as the index buffer of `avatar_mesh` and flags it for upload.
 * It is presumably used to restore the index returned by filter_faces_by_alpha.
 */
export function set_avatar_indices(avatar_mesh: SkinnedMesh, indices: BufferAttribute) {
  const { geometry } = avatar_mesh;
  geometry.setIndex(indices);
  const installed = geometry.index!;
  installed.needsUpdate = true;
}

// GLSL helper interpolated into the compute shaders of this module. It converts linear RGB
// to sRGB with the piecewise transfer function: a linear segment (x * 12.92) below 0.0031308
// and 1.055 * x^(1/2.4) - 0.055 above it.
// NOTE(review): mix(vec3, vec3, bvec3) is a GLSL ES 3.00 overload — confirm the renderer
// is WebGL2.
const fromLinear = `
vec3 fromLinear(vec3 linearRGB)
{
  bvec3 cutoff = lessThan(linearRGB, vec3(0.0031308));
  vec3 higher = vec3(1.055)*pow(linearRGB, vec3(1.0/2.4)) - vec3(0.055);
  vec3 lower = linearRGB * vec3(12.92);

  return mix(higher, lower, cutoff);
}`;

/**
 * Runs `shader_code` once as a full-screen GPGPU pass and returns the result as a CPU-side
 * RGBA8 DataTexture of `texture_size` x `texture_size`.
 *
 * The fragment shader derives its UV as `gl_FragCoord.xy / resolution.xy`. The entries of
 * `uniforms` are attached to the compute material by reference (they are not copied).
 *
 * Properties of the returned texture:
 * - tagged sRGB;
 * - linear filtering without mipmaps;
 * - `userData.mimeType = 'image/jpeg'`;
 * - already flagged for upload.
 *
 * The render targets, compute material and seed texture created for the pass are released
 * before returning. Previously every call leaked them.
 */
function _computeTexture(texture_size: number, shader_code: string, uniforms: Record<string, Uniform<any>>) {
  const renderer = AvatarView.renderer;
  const gpuCompute = new GPUComputationRenderer(texture_size, texture_size, renderer);
  gpuCompute.setDataType(UnsignedByteType);
  const seedTexture = gpuCompute.createTexture();
  const resultVar = gpuCompute.addVariable('resTexture', shader_code, seedTexture);

  for (const [name, uniform] of Object.entries(uniforms)) {
    resultVar.material.uniforms[name] = uniform;
  }

  const pixelBuffer = new Uint8Array(texture_size * texture_size * 4);
  try {
    const error = gpuCompute.init();
    if (error !== null) {
      // Best effort, as before: log and still attempt the pass.
      // NOTE(review): init() may report capabilities (e.g. float textures) that this
      // byte-typed pass does not need.
      console.error(error);
    }

    gpuCompute.compute();

    renderer.readRenderTargetPixels(
      gpuCompute.getCurrentRenderTarget(resultVar),
      0,
      0,
      texture_size,
      texture_size,
      pixelBuffer
    );
  } finally {
    // A new compute renderer is built for every call, so free its GPU resources now that
    // the pixels have been copied out.
    // NOTE(review): newer three.js versions also offer gpuCompute.dispose(), which
    // additionally frees the internal quad; prefer it once the project's version has it.
    for (const target of resultVar.renderTargets) {
      target.dispose();
    }
    resultVar.material.dispose();
    seedTexture.dispose();
  }

  const res = new DataTexture(pixelBuffer, texture_size, texture_size, RGBAFormat, UnsignedByteType);
  res.encoding = sRGBEncoding;
  res.generateMipmaps = false;
  res.minFilter = LinearFilter;
  res.magFilter = LinearFilter;
  res.userData.mimeType = 'image/jpeg';
  // three.js only uploads a DataTexture after it has been flagged.
  res.needsUpdate = true;
  return res;
}

/**
 * Bakes the `Avatar` skin adjustments into a new diffuse texture and installs it as
 * `material_.map`. The adjustments are curve, brightness/contrast and, optionally, the hair cap.
 *
 * NOTE(review): the current `material_.map` is the bake input, so calling this again
 * without restoring the original map presumably re-applies the adjustments — confirm that
 * callers reset the map between bakes.
 */
export function update_avatar_skin_texture(
  material_: MeshStandardMaterial,
  texture_size: number,
  with_haircap: boolean
) {
  const baked = _compute_avatar_texture(material_, Avatar, texture_size, with_haircap);
  baked.needsUpdate = true;
  material_.map = baked;
  material_.needsUpdate = true;
}

/**
 * Resamples `offset_map` through `transform_map`.
 *
 * For each output texel, the UV to read is taken from `transform_map`: its G channel gives
 * u and 1 - B gives v. The green channel of `offset_map` at that UV is written to all four
 * RGBA channels.
 *
 * NOTE(review): _computeTexture tags the result sRGB although it holds mask data — confirm
 * that consumers only read it on the CPU.
 */
export function compute_transformed_alpha_map(offset_map: Texture, transform_map: Texture, texture_size: number) {
  const shader_text = `
  uniform sampler2D alphaMap;
  uniform sampler2D transformMap;

  void main()	{
    vec2 vUv = gl_FragCoord.xy / resolution.xy;
    vec4 p = texture2D( transformMap, vUv );
    gl_FragColor = vec4(texture2D( alphaMap, vec2(p.g, 1.0-p.b)).g);
  }
  `;
  const uniforms: Record<string, Uniform<any>> = {
    alphaMap: new Uniform<Texture>(offset_map),
    transformMap: new Uniform<Texture>(transform_map),
  };
  return _computeTexture(texture_size, shader_text, uniforms);
}

/**
 * Copies `texture` into a new CPU-side DataTexture of `texture_size` x `texture_size`.
 * The texture is sampled once per output texel, so its pixels become readable from
 * `image.data`.
 */
export function textureToArray(texture: Texture, texture_size: number) {
  const copyShader = `
  uniform sampler2D texture1;

  void main()	{
    vec2 vUv = gl_FragCoord.xy / resolution.xy;
    gl_FragColor = texture2D( texture1, vUv );
  }
  `;
  return _computeTexture(texture_size, copyShader, { texture1: new Uniform<Texture>(texture) });
}

/**
 * Bakes the eye shader adjustments (driven by AvatarAnimatable.eye_params) into a new
 * diffuse texture and installs it as the map of `eye_mesh`'s material.
 */
export function update_eye_texture(eye_mesh: SkinnedMesh, texture_size = 256) {
  const eyeMaterial = eye_mesh.material as MeshStandardMaterial;
  const baked = compute_eye_texture(eyeMaterial, texture_size);
  baked.needsUpdate = true;
  eyeMaterial.map = baked;
  eyeMaterial.needsUpdate = true;
}

/**
 * Bakes the eyelash material colour into its map and installs the bake as the map of
 * `eyelash_mesh`'s material.
 *
 * NOTE(review): the material keeps its `color`, which the bake already multiplies into the
 * map — confirm the colour is not applied twice at render or export time.
 */
export function update_eyelash_texture(eyelash_mesh: SkinnedMesh, texture_size = 256) {
  const lashMaterial = eyelash_mesh.material as MeshStandardMaterial;
  const baked = compute_eyelash_texture(lashMaterial, texture_size);
  baked.needsUpdate = true;
  lashMaterial.map = baked;
  lashMaterial.needsUpdate = true;
}

/**
 * Bakes `eye_material.color` into its map. Each output texel is
 * `fromLinear(map.rgb * color)`, and the map's alpha is preserved.
 *
 * The previous shader could not compile, for three reasons:
 * - it sampled the undeclared `textureEye`;
 * - it declared `colorEyelash` as a plain global, so the material colour never reached it;
 * - it called `fromLinear` without including it.
 */
function compute_eyelash_texture(eye_material: MeshStandardMaterial, texture_size: number) {
  const uniforms: Record<string, Uniform<any>> = {
    textureEyelash: new Uniform<Texture>(eye_material.map!),
    colorEyelash: new Uniform<Color>(eye_material.color),
  };

  const shader_text = `
  uniform sampler2D textureEyelash;
  uniform vec3 colorEyelash;
  ${fromLinear}

  void main() {

    vec2 vUv = gl_FragCoord.xy / resolution.xy;
    vec4 color = texture2D( textureEyelash, vUv );

    gl_FragColor = vec4( fromLinear(color.xyz * colorEyelash), color.a );
  }
  `;
  return _computeTexture(texture_size, shader_text, uniforms);
}

/**
 * Runs the `eye_color` shader function over the eye map and returns an opaque,
 * sRGB-converted bake. `eye_color` comes from eye_pars, with color_ramp_pars available.
 * The shader uniforms are a shallow copy of AvatarAnimatable.eye_params plus `textureEye`.
 */
function compute_eye_texture(eye_material: MeshStandardMaterial, texture_size: number) {
  const eyeUniforms: EyeShaderParams & { textureEye?: any } = Object.assign({}, AvatarAnimatable.eye_params, {
    textureEye: new Uniform<Texture>(eye_material.map!),
  });

  const shader_text = `
  uniform sampler2D textureEye;
  ${color_ramp_pars}
  ${eye_pars}
  ${fromLinear}

  void main()	{

    vec2 vUv = gl_FragCoord.xy / resolution.xy;
    vec3 color = texture2D( textureEye, vUv ).rgb;

    color = eye_color(color, vUv);
    color = fromLinear(color);

    gl_FragColor = vec4(color, 1.0);
  }
  `;
  return _computeTexture(texture_size, shader_text, eyeUniforms as unknown as Record<string, Uniform<any>>);
}

/**
 * Bakes roughness, metalness and ambient occlusion into one packed texture and installs it
 * on the mesh's material. Both the maps and the scalar factors are baked in.
 *
 * Channel layout: R = occlusion, G = roughness, B = metalness.
 * The scalar factors are reset to 1 because they are now part of the texture. The AO map
 * is only replaced when the material already had one.
 */
export function update_ORM_map(mesh: SkinnedMesh, texture_size = 256) {
  const material = mesh.material as MeshStandardMaterial;
  const packed = compute_ORM_texture(material, texture_size);
  // three.js only uploads a DataTexture once flagged; the map and diffuse paths in this
  // module already do this, and the ORM path previously did not.
  packed.needsUpdate = true;

  material.roughnessMap = packed;
  material.metalnessMap = packed;
  if (material.aoMap) {
    material.aoMap = packed;
  }

  material.roughness = 1.0;
  material.metalness = 1.0;
  material.aoMapIntensity = 1.0;

  material.needsUpdate = true;
}

/**
 * Creates a 2x2 RGBA DataTexture with every byte set to 255 (opaque white), flagged for
 * upload. It stands in for missing maps in compute_ORM_texture.
 */
function _getDummyTexture() {
  const white = new Uint8Array(2 * 2 * 4).fill(255);
  const dummy = new DataTexture(white, 2, 2, RGBAFormat);
  dummy.needsUpdate = true;
  return dummy;
}

/**
 * Packs the occlusion, roughness and metalness of `material` into one RGBA texture, with
 * the scalar factors baked in. The layout is R = occlusion, G = roughness, B = metalness,
 * A = 1, so the result can be used with roughness = metalness = aoMapIntensity = 1.
 *
 * Channel sources follow MeshStandardMaterial:
 * - roughness comes from the roughness map's G channel;
 * - metalness comes from the metalness map's B channel (previously it was always read from
 *   the roughness map);
 * - occlusion comes from the AO map's R channel, using `(ao - 1) * aoMapIntensity + 1`.
 *
 * Missing maps are replaced by white placeholders, which are disposed after the bake. The
 * result is tagged linear because it holds material parameters, not colour.
 */
function compute_ORM_texture(material: MeshStandardMaterial, texture_size: number) {
  const placeholders: Texture[] = [];
  const orWhite = (map: Texture | null): Texture => {
    if (map) {
      return map;
    }
    const dummy = _getDummyTexture();
    placeholders.push(dummy);
    return dummy;
  };

  const uniforms: Record<string, Uniform<any>> = {
    rmMap: new Uniform<Texture>(orWhite(material.roughnessMap)),
    // Sampled separately so that a standalone metalness map (or none) is honoured.
    metalMap: new Uniform<Texture>(orWhite(material.metalnessMap)),
    aoMap: new Uniform<Texture>(orWhite(material.aoMap)),
    roughness: new Uniform<number>(material.roughness),
    metalness: new Uniform<number>(material.metalness),
    aoMapIntensity: new Uniform<number>(material.aoMapIntensity),
  };

  const shader_text = `
  uniform sampler2D rmMap;
  uniform sampler2D metalMap;
  uniform sampler2D aoMap;
  uniform float roughness;
  uniform float metalness;
  uniform float aoMapIntensity;

  void main() {

    vec2 vUv = gl_FragCoord.xy / resolution.xy;

    // rm
    float r = clamp(texture2D( rmMap, vUv ).g * roughness, 0.0, 1.0);
    float m = clamp(texture2D( metalMap, vUv ).b * metalness, 0.0, 1.0);

    // o: same intensity formula as MeshStandardMaterial's AO
    float o = texture2D( aoMap, vUv ).r;
    o = clamp((o - 1.0) * aoMapIntensity + 1.0, 0.0, 1.0);

    gl_FragColor = vec4(o, r, m, 1.0);
  }
  `;

  try {
    const res = _computeTexture(texture_size, shader_text, uniforms);
    // Keep packed parameters out of sRGB decoding when sampled.
    res.encoding = LinearEncoding;
    return res;
  } finally {
    placeholders.forEach((placeholder) => placeholder.dispose());
  }
}

/**
 * Bakes the avatar skin adjustments into a copy of `material.map`. The steps run in order:
 * 1. curve;
 * 2. brightness/contrast;
 * 3. the hair cap, only when `with_haircap` is set;
 * 4. conversion to sRGB.
 *
 * Uniforms are merged from `avatarClass.curve_params`, `hairCapShaderParams` and
 * `avatarClass.brightness_contrast_params`, in that order, so later sources win on name
 * clashes. `textureAvatar` is added last.
 */
function _compute_avatar_texture(
  material: MeshStandardMaterial,
  avatarClass: typeof Avatar | typeof AvatarAnimatable,
  texture_size = 1024,
  with_haircap = true
) {
  const haircapStep = with_haircap ? 'color = applyHairCap(color, vUv);' : '';
  const shader_text = `
      uniform sampler2D textureAvatar;
      ${brightness_contrast_pars}
      ${hair_cap_pars}
      ${curve_pars}
      ${fromLinear}

			void main()	{

				vec2 vUv = gl_FragCoord.xy / resolution.xy;
				vec3 color = texture2D( textureAvatar, vUv ).xyz;

        color = curve(color);
        color = brightness_contrast(color);
        ${haircapStep}
        color = fromLinear(color);

        gl_FragColor = vec4( color, 1.0);
			}
    `;

  const uniforms: Record<string, Uniform<any>> = {};
  for (const source of [avatarClass.curve_params, hairCapShaderParams, avatarClass.brightness_contrast_params]) {
    for (const [name, uniform] of Object.entries(source)) {
      uniforms[name] = uniform;
    }
  }
  uniforms['textureAvatar'] = new Uniform<Texture>(material.map!);

  return _computeTexture(texture_size, shader_text, uniforms);
}

/**
 * Applies the hair colour ramp (`color_ramp_params`) to `hair_texture` and returns an RGBA
 * bake converted to sRGB. Texels with alpha below 2/255 have their RGB zeroed. The bake is
 * tagged 'image/png', presumably so the alpha channel survives export.
 *
 * Filtering:
 * - minification is mipmapped, so mipmaps are generated on upload; with
 *   `generateMipmaps = false` the texture would be mipmap-incomplete;
 * - magnification uses LinearFilter, because mipmap filters are not valid for magFilter.
 */
function _compute_hair_texture(hair_texture: Texture, texture_size = 1024) {
  const shader_text = `
      uniform sampler2D textureHair;
      ${color_ramp_pars}
      ${fromLinear}

			void main()	{

				vec2 vUv = gl_FragCoord.xy / resolution.xy;
				vec4 rgba = texture2D( textureHair, vUv );
        vec3 color = rgba.rgb;

        color = color_ramp(color);
        color = clamp(fromLinear(color), 0.0, 1.0);

        if (rgba.a < 2.0/255.0) {
          color = vec3(0.0);
        }
        gl_FragColor = vec4( color , rgba.a);
			}
    `;

  const uniforms: Record<string, Uniform<any>> = {};
  for (const [name, uniform] of Object.entries(color_ramp_params)) {
    uniforms[name] = uniform;
  }
  uniforms['textureHair'] = new Uniform<Texture>(hair_texture);

  const res = _computeTexture(texture_size, shader_text, uniforms);
  res.userData.mimeType = 'image/png';
  res.minFilter = LinearMipMapLinearFilter;
  res.generateMipmaps = true;
  res.magFilter = LinearFilter;
  return res;
}

/**
 * Replaces the diffuse map of every hair mesh in `hair_group` with a colour-ramped bake
 * produced by _compute_hair_texture.
 *
 * Skipped nodes:
 * - nodes without a material or map;
 * - materials whose `userData.name` contains 'no_ramp'.
 *
 * Each distinct source map is baked once, and the bake is reused by every material that
 * referenced it. A material shared by several meshes is handled only once: after the first
 * visit its map is already a bake, and re-baking it would apply the ramp twice.
 *
 * NOTE(review): the replaced maps are neither disposed nor saved to `oldMap` here.
 * restoreOldMap / reset_hair_maps therefore only undo this if a caller stored `oldMap`
 * beforehand — confirm.
 */
export function update_hair_diffuse(hair_group: Group, texture_size: number) {
  // source map uuid -> bake produced from it during this call
  const bakesBySource = new Map<string, Texture>();
  // uuids of the bakes produced during this call
  const bakedUuids = new Set<string>();

  for (const node of hair_group.children) {
    const material = (node as SkinnedMesh).material as MeshStandardMaterial | undefined;
    if (!material || material.userData['name']?.includes('no_ramp')) {
      continue;
    }

    const source = material.map;
    if (!source || bakedUuids.has(source.uuid)) {
      // Nothing to bake, or a shared material that was already processed in this call.
      continue;
    }

    const cached = bakesBySource.get(source.uuid);
    if (cached) {
      material.map = cached;
      continue;
    }

    const baked = _compute_hair_texture(source, texture_size);
    baked.needsUpdate = true;
    material.map = baked;
    material.needsUpdate = true;

    bakesBySource.set(source.uuid, baked);
    bakedUuids.add(baked.uuid);
  }
}

/** Calls restoreOldMap on the material of every child of `hair_group`. */
export function reset_hair_maps(hair_group: Group) {
  for (const child of hair_group.children) {
    const material = (child as SkinnedMesh).material as MeshStandardMaterial;
    restoreOldMap(material);
  }
}

/**
 * Swaps `material.map` back to the texture saved in `material.oldMap` and then deletes
 * `oldMap`. The texture being replaced is disposed in a queued microtask rather than
 * synchronously. Does nothing when no `oldMap` is stored.
 */
export function restoreOldMap(material: MeshStandardMaterial & { oldMap?: Texture }) {
  const { oldMap } = material;
  if (!oldMap) {
    return;
  }

  const replaced = material.map;
  queueMicrotask(() => replaced?.dispose());

  material.map = oldMap;
  oldMap.needsUpdate = true;
  material.needsUpdate = true;

  delete material.oldMap;
}

/**
 * Returns a clone of `mat` with alpha masking removed (`alphaMap = null`, `alphaTest = 0`).
 * The input material is not modified.
 *
 * NOTE(review): other fields such as `transparent` and `opacity` are left as clone()
 * produced them — confirm callers do not expect an opaque blend mode as well.
 */
export function convert_to_opaque(mat: MeshStandardMaterial) {
  const opaque = mat.clone();
  opaque.alphaTest = 0;
  opaque.alphaMap = null;
  return opaque;
}
