import { Color, Shader } from 'three';
import { Uniform } from './_common';

/**
 * Uniforms consumed by the colour-ramp GLSL chunk in this file.
 *
 * The property names match the `uniform` declarations in `color_ramp_pars`
 * one-to-one, because `color_ramp` copies each entry onto `shader.uniforms`
 * under its own key.
 */
export interface ColorRampParams {
  /** Colour returned for inputs below `ramp_val1`, and start of the first blend. */
  ramp_color1: Uniform<Color>;
  /** Colour reached at `ramp_val2`; end of the first blend and start of the second. */
  ramp_color2: Uniform<Color>;
  /** Colour returned for inputs at or above `ramp_val3`, and end of the second blend. */
  ramp_color3: Uniform<Color>;

  /** Position of the first stop, compared against the average of the input's three channels. */
  ramp_val1: Uniform<number>;
  /** Position of the second stop. */
  ramp_val2: Uniform<number>;
  /** Position of the third stop. */
  ramp_val3: Uniform<number>;
}

/**
 * Uniforms used by `color_ramp` when the caller does not pass its own.
 *
 * The active entries are the "Blond" colour stops combined with the "new"
 * stop positions. The commented-out entries are earlier presets kept for
 * reference; they are written as bare values (not wrapped in `Uniform`) and
 * would need wrapping before being re-enabled.
 */
const default_color_ramp_params: ColorRampParams = {
  // Disabled preset: the "Brown" colours below, converted from linear to sRGB.
  // ramp_color1: new THREE.Color(0.023, 0.011, 0.008).convertLinearToSRGB(),
  // ramp_color2: new THREE.Color(0.152, 0.056, 0.037).convertLinearToSRGB(),
  // ramp_color3: new THREE.Color(0.339, 0.152, 0.116).convertLinearToSRGB(),

  // Disabled preset: Brown
  // ramp_color1: new THREE.Color(0.023, 0.011, 0.008),
  // ramp_color2: new THREE.Color(0.152, 0.056, 0.037),
  // ramp_color3: new THREE.Color(0.339, 0.152, 0.116),

  // Active colour stops (same values as the "Blond" preset listed below).
  ramp_color1: new Uniform(new Color(0.103, 0.075, 0.057)),
  ramp_color2: new Uniform(new Color(0.788, 0.527, 0.266)),
  ramp_color3: new Uniform(new Color(0.889, 0.432, 0.174)),

  // Reference copy of the "Blond" preset in the older bare-value form.
  // ramp_color1: new THREE.Color(0.103, 0.075, 0.057),
  // ramp_color2: new THREE.Color(0.788, 0.527, 0.266),
  // ramp_color3: new THREE.Color(0.889, 0.432, 0.174),

  // Disabled preset: previous stop positions.
  // ramp_val1: 0.136,
  // ramp_val2: 0.259,
  // ramp_val3: 0.341,

  // Active stop positions (the "new" set).
  ramp_val1: new Uniform(0.027),
  ramp_val2: new Uniform(0.414),
  ramp_val3: new Uniform(0.791),
};

/**
 * GLSL source that declares the colour-ramp uniforms and the `color_ramp`
 * function; `color_ramp()` in this file prepends it to a fragment shader.
 *
 * The GLSL `color_ramp(inputColor)` averages the three channels of its input
 * and maps that average onto a three-stop gradient:
 *
 * - below `ramp_val1`: `ramp_color1`
 * - from `ramp_val1` up to `ramp_val2`: linear mix of `ramp_color1` and `ramp_color2`
 * - from `ramp_val2` up to `ramp_val3`: linear mix of `ramp_color2` and `ramp_color3`
 * - at or above `ramp_val3`: `ramp_color3`
 *
 * NOTE(review): the branch order appears to assume
 * `ramp_val1 <= ramp_val2 <= ramp_val3`; confirm that callers keep the stops
 * in ascending order.
 */
export const color_ramp_pars = `
uniform vec3 ramp_color1;
uniform vec3 ramp_color2;
uniform vec3 ramp_color3;

uniform float ramp_val1;
uniform float ramp_val2;
uniform float ramp_val3;

vec3 color_ramp(const in vec3 inputColor) {
  float avg = (inputColor.x + inputColor.y + inputColor.z) / 3.0;
  vec3 ramp_res;
  if (avg < ramp_val1) {
      ramp_res = ramp_color1;
  } else if (avg < ramp_val2) {
      ramp_res = mix(ramp_color1, ramp_color2, (avg - ramp_val1) / (ramp_val2 - ramp_val1) );
  } else if (avg < ramp_val3) {
      ramp_res = mix(ramp_color2, ramp_color3, (avg - ramp_val2) / (ramp_val3 - ramp_val2) );
  } else {
      ramp_res = ramp_color3;
  }

  return ramp_res;
}
`;

/**
 * Patches a shader in place so its diffuse colour is remapped through the
 * colour ramp.
 *
 * Three things happen, in order:
 * 1. every entry of `color_ramp_params` is assigned (by reference) onto
 *    `shader.uniforms` under its own key;
 * 2. the `color_ramp_pars` GLSL chunk is prepended to `shader.fragmentShader`;
 * 3. the first `#include <emissivemap_fragment>` in the resulting source is
 *    replaced by the same include followed by a statement that runs
 *    `diffuseColor.xyz` through the GLSL `color_ramp` function.
 *
 * If the include line is absent, only steps 1 and 2 take effect.
 *
 * @param shader - Shader object whose `uniforms` and `fragmentShader` are modified.
 * @param color_ramp_params - Uniforms to install; defaults to the module-level presets.
 */
export function color_ramp(shader: Shader, color_ramp_params: ColorRampParams = default_color_ramp_params) {
  for (const [uniformName, uniform] of Object.entries(color_ramp_params)) {
    shader.uniforms[uniformName] = uniform;
  }

  // Hook point in the stock fragment source, and the text that replaces it.
  const anchor = '#include <emissivemap_fragment>';
  const anchorWithRamp = `
    #include <emissivemap_fragment>
    diffuseColor.xyz = color_ramp(diffuseColor.xyz);
    `;

  // Declarations go first so the injected call can resolve `color_ramp`.
  const sourceWithDeclarations = color_ramp_pars + shader.fragmentShader;
  shader.fragmentShader = sourceWithDeclarations.replace(anchor, anchorWithRamp);
}
