import { Color, Shader } from 'three';
import { Uniform } from './_common';

/**
 * Uniforms controlling one fresnel highlight layer injected by `fresnel()`.
 *
 * `fresnel()` binds each entry into `shader.uniforms` as `<key>_<counter>`,
 * and the injected GLSL declares uniforms with exactly those names, so these
 * keys must not be renamed independently of the shader code.
 */
export interface FresnelParams {
  /**
   * Direction packed into a color (the shader code references it as
   * `(color - 0.5) * 2.0`).
   * NOTE(review): it is declared and bound as a uniform, but no active shader
   * path in `fresnel()` uses it to affect the output — confirm whether it is still needed.
   */
  fresnel_dir: Uniform<Color>;
  /** Index of refraction passed to the GLSL dielectric fresnel function. */
  fresnel_ior: Uniform<number>;
  /** Multiplier on the fresnel term before it is added to `diffuseColor`. */
  fresnel_strength: Uniform<number>;
  /**
   * When true and the shader defines USE_TANGENT, the fresnel term is evaluated
   * against `cross(normal, vTBN[0])` instead of the surface normal.
   */
  fresnel_use_custom_direction: Uniform<boolean>;
}

/**
 * Fallback parameters used by `fresnel()` when the caller passes none
 * (this set was labelled as the "hair" tuning).
 *
 * These Uniform instances are bound directly (not copied) into every shader
 * patched with the defaults, so changing one of their `.value`s here affects
 * all of those shaders at once.
 *
 * Alternative values that were tried previously:
 *   - hair: ior 1.1 / strength 0.10
 *   - hair: ior 1.43 / strength 0.06
 *   - hair: strength 1.0 (0.10 * 10) with fresnel_use_custom_direction = false
 *   - other surfaces: dir Color(0.0, 0.0, 1.0), ior 1.05, strength 0.15,
 *     fresnel_use_custom_direction = false
 */
const default_fresnel_params: FresnelParams = {
  fresnel_dir: new Uniform(new Color(0.5, 0.0, 0.5)),
  fresnel_ior: new Uniform(1.17),
  fresnel_strength: new Uniform(0.04),
  fresnel_use_custom_direction: new Uniform(true),
};

// let counter = 0;

/**
 * GLSL replacement for three.js's `#include <normal_fragment_maps>`.
 *
 * Appears to mirror the stock chunk, except that the tangent-space normal-map
 * path without vertex tangents calls `perturbNormal2ArbHair`. That function
 * reads screen-space derivatives from the global `der`, which are captured
 * earlier in main() by DERIVATIVES_CAPTURE.
 */
const HAIR_NORMAL_FRAGMENT_MAPS = `
#ifdef OBJECTSPACE_NORMALMAP

  normal = texture2D( normalMap, vUv ).xyz * 2.0 - 1.0; // overrides both flatShading and attribute normals

  #ifdef FLIP_SIDED
    normal = - normal;
  #endif

  #ifdef DOUBLE_SIDED
    normal = normal * faceDirection;
  #endif

  normal = normalize( normalMatrix * normal );

#elif defined( TANGENTSPACE_NORMALMAP )

  vec3 mapN = texture2D( normalMap, vUv ).xyz * 2.0 - 1.0;
  mapN.xy *= normalScale;

  #ifdef USE_TANGENT
    normal = normalize( vTBN * mapN );
  #else
    normal = perturbNormal2ArbHair( - vViewPosition, normal, mapN, faceDirection, der );
  #endif

#elif defined( USE_BUMPMAP )

  normal = perturbNormalArb( - vViewPosition, normal, dHdxy_fwd(), faceDirection );

#endif
`;

/**
 * GLSL helpers shared by every fresnel layer in a shader. They must be
 * injected exactly once per shader, because GLSL rejects redefinitions.
 *
 * - `Derivatives` / global `der`: holds screen-space derivatives of the view
 *   position and UVs, filled in by DERIVATIVES_CAPTURE.
 * - `fresnel_dielectric_cos`: dielectric fresnel reflectance, computed without
 *   the refracted direction. It returns 1.0 on total internal reflection.
 * - `fresnel`: wraps the above. The `max()` guard is kept from the original,
 *   where eta could have been `1.0 / IOR`. That inversion was hard-wired off,
 *   so eta is always the clamped IOR.
 * - `perturbNormal2ArbHair`: normal-map perturbation that uses the precomputed
 *   derivatives in `der` instead of calling dFdx/dFdy itself.
 */
const FRESNEL_GLSL_FUNCTIONS = `
struct Derivatives {
  vec3 q0;
  vec3 q1;
  vec2 st0;
  vec2 st1;
};
Derivatives der;

float fresnel_dielectric_cos(float cosi, float eta) {
  float c = abs(cosi);
  float g = eta * eta - 1.0 + c * c;
  float result;

  if (g > 0.0) {
    g = sqrt(g);
    float A = (g - c) / (g + c);
    float B = (c * (g + c) - 1.0) / (c * (g - c) + 1.0);
    result = 0.5 * A * A * (1.0 + B * B);
  } else {
    result = 1.0; /* TIR (no refracted component) */
  }

  return result;
}

float fresnel(float IOR, float dotNVi) {
  float eta = max(IOR, 1e-5);
  return fresnel_dielectric_cos(dotNVi, eta);
}

vec3 perturbNormal2ArbHair( vec3 eye_pos, vec3 surf_norm, vec3 mapN, float faceDirection, Derivatives der ) {
  vec3 q0 = der.q0;
  vec3 q1 = der.q1;
  vec2 st0 = der.st0;
  vec2 st1 = der.st1;

  vec3 N = surf_norm; // normalized

  vec3 q1perp = cross( q1, N );
  vec3 q0perp = cross( N, q0 );

  vec3 T = q1perp * st0.x + q0perp * st1.x;
  vec3 B = q1perp * st0.y + q0perp * st1.y;

  float det = max( dot( T, T ), dot( B, B ) );
  float scale = ( det == 0.0 ) ? 0.0 : faceDirection * inversesqrt( det );

  return normalize( T * ( mapN.x * scale ) + B * ( mapN.y * scale ) + N * mapN.z );
}
`;

/**
 * Replacement for `#include <alphatest_fragment>`. It records derivatives
 * into `der` first and then re-emits the original include.
 * NOTE(review): presumably the capture happens here so the derivatives are
 * taken before the alpha-test discard and outside any branch. Confirm intent.
 */
const DERIVATIVES_CAPTURE = `
der.q0 = dFdx( (- vViewPosition).xyz );
der.q1 = dFdy( (- vViewPosition).xyz );
der.st0 = dFdx( vUv.st );
der.st1 = dFdy( vUv.st );

#include <alphatest_fragment>
`;

/** Uniform declarations for one fresnel layer. The names must match the `FresnelParams` keys suffixed with `_<counter>`. */
function fresnel_uniform_declarations(counter: number): string {
  return `
uniform vec3 fresnel_dir_${counter};
uniform float fresnel_ior_${counter};
uniform float fresnel_strength_${counter};
uniform bool fresnel_use_custom_direction_${counter};
`;
}

/**
 * Replacement for `#include <emissivemap_fragment>`. It re-emits the include,
 * then adds this layer's white fresnel highlight to `diffuseColor`.
 *
 * The direction compared against the view vector defaults to the
 * (normal-mapped) surface normal. Only when USE_TANGENT is defined and the
 * custom-direction flag is set is it replaced by `cross(normal, vTBN[0])`.
 * Using that cross product instead of vTBN[1] makes the fresnel follow the
 * normal map.
 *
 * Previously `normal_pivoted` was left uninitialized when the flag was set
 * without USE_TANGENT, which is undefined behaviour in GLSL. That case now
 * falls back to `normal`.
 *
 * NOTE(review): the cross product is not re-normalized, so `dotNVi` is only a
 * true cosine when `normal` is perpendicular to `vTBN[0]` and both are unit
 * length. It is left as-is because the look was tuned against it.
 */
function fresnel_highlight_chunk(counter: number): string {
  return `
#include <emissivemap_fragment>
{
  vec3 normal_pivoted = normal;
  #ifdef USE_TANGENT
    if (fresnel_use_custom_direction_${counter}) {
      normal_pivoted = cross( normal, vTBN[0] );
    }
  #endif

  float dotNVi = dot( normal_pivoted, normalize( vViewPosition ) );
  float fac = fresnel( fresnel_ior_${counter}, dotNVi ) * fresnel_strength_${counter};

  const vec3 highlight_color = vec3( 1.0, 1.0, 1.0 );
  diffuseColor.xyz = diffuseColor.xyz + fac * highlight_color;
}
`;
}

/**
 * Patches a three.js material shader in place so its fragment stage adds a
 * white fresnel rim highlight, scaled by `fresnel_strength`, to `diffuseColor`.
 * NOTE(review): presumably called from a material's `onBeforeCompile` hook. Confirm.
 *
 * Steps, in order:
 *  - bind every entry of `fresnel_params` into `shader.uniforms` as
 *    `<key>_<counter>`. The Uniform objects are shared, not copied, so every
 *    shader patched with the default params references the same instances;
 *  - replace `#include <normal_fragment_maps>` with HAIR_NORMAL_FRAGMENT_MAPS;
 *  - prepend this layer's uniform declarations;
 *  - if `define_fns` is true, insert FRESNEL_GLSL_FUNCTIONS before `void main() {`;
 *  - capture derivatives ahead of `#include <alphatest_fragment>`;
 *  - add the highlight right after `#include <emissivemap_fragment>`.
 * String.replace leaves the source untouched when an anchor is missing. The
 * material therefore needs these chunk includes and the literal `void main() {`
 * for the effect to apply and compile.
 *
 * Several layers can be stacked on one shader by calling again with a
 * different `counter` and `define_fns = false`.
 *
 * NOTE: the `fresnel_dir` uniform is declared and bound, but no active code
 * path reads it. Earlier experiments that pivoted the normal toward it (via
 * `vTBN * dir` or `perturbNormal2ArbHair`) were disabled.
 *
 * @param shader         Shader whose `uniforms` and `fragmentShader` are mutated in place.
 * @param fresnel_params Uniforms for this layer. Defaults to the shared `default_fresnel_params`.
 * @param define_fns     Whether to inject the GLSL helpers (`fresnel`, `perturbNormal2ArbHair`,
 *                       global `der`). Both the derivative capture and the highlight reference
 *                       them, so pass false only if an earlier call already injected them.
 * @param counter        Suffix that keeps this layer's uniform names unique within the shader.
 * @returns `counter`, unchanged.
 */
export function fresnel(
  shader: Shader,
  fresnel_params: FresnelParams = default_fresnel_params,
  define_fns = true,
  counter = 0
): number {
  for (const [key, uniform] of Object.entries(fresnel_params)) {
    shader.uniforms[`${key}_${counter}`] = uniform;
  }

  let fs = shader.fragmentShader;
  fs = fs.replace('#include <normal_fragment_maps>', HAIR_NORMAL_FRAGMENT_MAPS);
  fs = fresnel_uniform_declarations(counter) + fs;
  if (define_fns) {
    fs = fs.replace('void main() {', FRESNEL_GLSL_FUNCTIONS + '\nvoid main() {');
  }
  fs = fs.replace('#include <alphatest_fragment>', DERIVATIVES_CAPTURE);
  fs = fs.replace('#include <emissivemap_fragment>', fresnel_highlight_chunk(counter));
  shader.fragmentShader = fs;

  return counter;
}
