import { FloatType, LinearFilter, Matrix4, Mesh, PlaneGeometry, Scene, ShaderMaterial, Vector4, WebGLRenderTarget } from "three";
/**
 * Screen-space reflection (SSR) material, drawn on a full-screen quad by computeSSR().
 *
 * For every pixel the fragment shader reflects the view ray about the G-buffer
 * normal, ray-marches the reflected ray in view space against the position buffer,
 * refines the crossing with a binary search and writes:
 *   rgb = radiance at the hit point (env map * 0.1 when there is no valid hit)
 *   a   = world-space y of the hit point (0.0 when there is no valid hit)
 *
 * Uniforms (assigned by computeSSR):
 *   uPosition          view-space positions; xyz = position, w = mesh id
 *   uNormal            view-space normals
 *   uRadiance          lit scene colour, sampled at the hit point
 *   uEnvmap            environment map sampled (equirectangular lookup) on a miss
 *   uViewMatrixInverse view -> world matrix (camera.matrixWorld)
 */
const ssrMaterial = new ShaderMaterial({
  // The quad is emitted directly in clip space. The camera matrices are forwarded as
  // varyings because the fragment shader needs them to project view-space points back
  // onto the screen and to rotate reflection rays into world space.
  vertexShader: `
    out vec2 vUv;
    out mat4 vProjectionMatrix;
    out mat4 vViewMatrix;
    void main() {
      vUv = uv;
      vViewMatrix = viewMatrix;
      vProjectionMatrix = projectionMatrix;
      gl_Position = vec4(position, 1.0);
    }
  `,
  fragmentShader: `
    in vec2 vUv;
    in mat4 vProjectionMatrix;
    in mat4 vViewMatrix;
    uniform sampler2D uPosition;
    uniform sampler2D uNormal;
    uniform sampler2D uRadiance;
    uniform sampler2D uEnvmap;
    uniform mat4 uViewMatrixInverse;

    #define PI 3.14159265359
    #define TwoPI 6.28318530718
    #define OneOverPI 0.31830988618
    // Depth assigned to position-buffer texels with z == 0 (nothing rendered there),
    // so the empty background is never reported as a hit.
    #define EMPTY_DEPTH 9999999.0

    float rand(float co) { return fract(sin(co * (91.3458)) * 47453.5453); }
    float rand(vec2 co)  { return fract(sin(dot(co.xy, vec2(12.9898, 78.233))) * 43758.5453); }
    float rand(vec3 co)  { return rand(co.xy + rand(co.z)); }

    // Samples the environment map along the world-space direction idir using an
    // equirectangular (atan / asin) mapping of the swizzled direction. The texel is
    // used as-is (no RGBE decode) and scaled down by 0.1.
    vec3 getEnvmapRadiance(vec3 idir) {
      vec3 dir = idir.zyx;
      vec2 skyboxUV = vec2(
        (atan(dir.x, dir.z) + PI) / (PI * 2.0),
        (asin(dir.y) + PI * 0.5) / PI
      );
      return texture2D(uEnvmap, skyboxUV).xyz * 0.1;
    }

    // Projects a view-space point to [0, 1] screen UVs.
    vec2 viewToUv(vec3 p) {
      vec4 projP = vProjectionMatrix * vec4(p, 1.0);
      return (projP.xy / projP.w) * 0.5 + 0.5;
    }

    // Absolute view-space depth stored in the position buffer at uv
    // (EMPTY_DEPTH where nothing was rendered).
    float depthBufferAtUv(vec2 uv) {
      float depth = abs(texture2D(uPosition, uv).z);
      return depth == 0.0 ? EMPTY_DEPTH : depth;
    }

    // Position-buffer depth at the screen location of the view-space point p.
    float depthBufferAtP(vec3 p) {
      return depthBufferAtUv(viewToUv(p));
    }

    // Marches the view-space ray starting at ro along rd with geometrically growing,
    // jittered steps until a sample falls behind the position buffer, then refines
    // the crossing with a binary search.
    //   sampl         - seed offset for the per-pixel jitter
    //   intersectionP - refined hit point (ro when no crossing was found)
    //   lastP         - where the coarse march stopped
    // Returns true only when the refined point is within maxIntersectionDepthDistance
    // of the depth-buffer surface it crossed.
    bool intersect(
      vec3 ro, vec3 rd, float sampl,
      out vec3 intersectionP,
      out vec3 lastP)
    {
      const bool jitter = true;
      const float startingStep = 1.25;
      const float stepMult = 1.35;
      const int steps = 30;
      const int binarySteps = 7;
      const float maxIntersectionDepthDistance = 1.5;

      // The jitter depends only on the pixel and on sampl, so compute it once.
      // Each iteration advances by stepSize * jittA, samples, then advances by the
      // remaining stepSize * jittB so every sample stays inside its own step "cell".
      float randIndexer = (gl_FragCoord.x * 98.18 + gl_FragCoord.y * 197.88) * 0.171897 + sampl * 11.789;
      float jittA = jitter ? rand(randIndexer) : 1.0;
      float jittB = 1.0 - jittA;

      float stepSize = startingStep;
      vec3 p = ro;
      // p1 / p2 bracket the crossing: p1 in front of the depth buffer, p2 behind it.
      vec3 p1 = ro;
      vec3 p2 = ro;
      // Last jittered sample that was still in front of the depth buffer. Using the
      // start of the iteration instead could leave p1 and p2 on the same side of the
      // depth buffer, and the binary search needs them on opposite sides.
      vec3 initialP = ro;
      float lastRecordedDepthBuffThatIntersected = 0.0;
      bool possibleIntersection = false;

      for (int i = 0; i < steps; i++) {
        p += rd * stepSize * jittA;
        vec2 pUv = viewToUv(p);
        // Give up once the ray leaves the screen or reaches p.z > 0 (behind the camera).
        if (pUv.x < 0.0 || pUv.x > 1.0 || pUv.y < 0.0 || pUv.y > 1.0 || p.z > 0.0) {
          break;
        }
        float depthAtPosBuff = depthBufferAtUv(pUv);
        if (abs(p.z) > depthAtPosBuff) {
          // Crossing found: this sample is behind the stored surface.
          p1 = initialP;
          p2 = p;
          lastRecordedDepthBuffThatIntersected = depthAtPosBuff;
          possibleIntersection = true;
          break;
        }
        initialP = p;
        p += rd * stepSize * jittB;
        stepSize *= stepMult; // must come after the jittB advance
      }

      lastP = p;
      // Nothing was bracketed, so there is nothing to refine. Returning here also keeps
      // the binary search from running on meaningless endpoints.
      if (!possibleIntersection) {
        intersectionP = ro;
        return false;
      }

      // Binary search. It deliberately lives outside the march loop: the original
      // author found that nesting it inside made the GPU misbehave.
      for (int j = 0; j < binarySteps; j++) {
        vec3 mid = (p1 + p2) * 0.5;
        float depthAtPosBuff = depthBufferAtP(mid);
        if (abs(mid.z) > depthAtPosBuff) {
          p2 = mid;
          // Record the depth only when mid is behind the surface. A sample on the p1
          // side may land on the background (EMPTY_DEPTH), and such a huge value would
          // produce artifacts in the distance check below.
          lastRecordedDepthBuffThatIntersected = depthAtPosBuff;
        } else {
          p1 = mid;
        }
      }

      intersectionP = p2;
      // Reject crossings where the ray ended up too far behind the surface it crossed.
      return abs(abs(p2.z) - lastRecordedDepthBuffThatIntersected) < maxIntersectionDepthDistance;
    }

    void main() {
      vec3 positionVS = texture2D(uPosition, vUv).xyz;
      vec3 normalVS   = texture2D(uNormal, vUv).xyz;
      vec3 viewDir    = -normalize(positionVS);
      float depth     = abs(positionVS.z);

      // Mirror reflection of the view ray. The origin is nudged along the ray (more at
      // larger depths) so the march does not immediately hit its own surface.
      vec3 rd = normalize(reflect(-viewDir, normalVS));
      vec3 ro = positionVS + rd * max(0.01, 0.015 * depth);

      vec3 p2;
      vec3 lastP;
      bool intersected = intersect(ro, rd, 0.0, p2, lastP);

      vec3 radianceSum = vec3(0.0);
      float meshIdAtBounce = 0.0;
      if (intersected) {
        vec2 p2Uv = viewToUv(p2);
        // The radiance is used as-is, without brdf, 1/pdf or cos(theta). The author's
        // (self-described uncertain) reasoning: for a Lambertian brdf with cosine-weighted
        // sampling, Li * (1/PI) * (PI / cos(theta)) * cos(theta) = Li.
        radianceSum = texture2D(uRadiance, p2Uv).xyz;
        meshIdAtBounce = texture2D(uPosition, p2Uv).w;
      }

      // Alpha stores the world-space height (y) of the reflected point rather than the
      // actual reflection length (which would be length(p2 - positionVS)).
      float reflDistance = (uViewMatrixInverse * vec4(p2, 1.0)).y;

      // Meshes with id above 9.5 never appear in reflections: treat hits on them as misses.
      if (intersected && meshIdAtBounce > 9.5) {
        intersected = false;
      }
      // Stronger reflections for far-away geometry: x1 up to x2.5 past view depth ~1200.
      radianceSum *= clamp(abs(p2.z) - 1200.0, 1.0, 2.5);
      // Tint reflections of the main building (meshId == 1).
      if (intersected && meshIdAtBounce > 0.5 && meshIdAtBounce < 1.5) {
        radianceSum *= vec3(0.5, 0.3, 0.10) * 1.3;
      }
      if (!intersected) {
        reflDistance = 0.0;
        // rd is in view space; the transposed view rotation maps it back to world space.
        vec3 envRD = transpose(mat3(vViewMatrix)) * rd;
        radianceSum = getEnvmapRadiance(envRD);
      }
      gl_FragColor = vec4(radianceSum, reflDistance);
    }
  `,
  // Textures are assigned per frame by computeSSR(). three.js infers uniform types
  // from the shader, so no legacy `type` field is needed.
  uniforms: {
    uRadiance: { value: null },
    uPosition: { value: null },
    uNormal: { value: null },
    uEnvmap: { value: null },
    uViewMatrixInverse: { value: new Matrix4() },
  },
  name: "ssrMaterial",
});
// Full-screen quad: a 2x2 plane whose vertex shader writes `position` straight into
// clip space, so it covers the whole viewport regardless of the camera.
const ssrQuadMesh = new Mesh(new PlaneGeometry(2, 2), ssrMaterial);
// Culling against the camera frustum is meaningless for clip-space geometry; always draw it.
ssrQuadMesh.frustumCulled = false;
// Object3D.add() returns the parent, so the scene can be built in one expression.
const ssrScene = new Scene().add(ssrQuadMesh);
/**
 * Renders the screen-space reflection pass into `ssrRT`.
 *
 * The renderer's previously bound render target is restored afterwards, so the
 * caller's render state is left untouched.
 *
 * @param {WebGLRenderer} gl - Renderer used to draw the full-screen SSR quad.
 * @param {Camera} camera - Camera whose matrices the shader uses; its `matrixWorld`
 *   is bound by reference as the view -> world matrix.
 * @param {WebGLRenderTarget} ssrRT - Destination target for the SSR output.
 * @param {WebGLRenderTarget} radianceRT - Lit scene colour; its `texture` is sampled at hits.
 * @param {object} gBuffer - Multi-attachment G-buffer: `texture[0]` is bound as the
 *   normal buffer, `texture[1]` as the position buffer (w = mesh id).
 * @param {Texture} envMap - Environment map used where no reflection hit is found.
 */
export function computeSSR(gl, camera, ssrRT, radianceRT, gBuffer, envMap) {
  const { uniforms } = ssrMaterial;
  uniforms.uEnvmap.value = envMap;
  uniforms.uRadiance.value = radianceRT.texture;
  uniforms.uNormal.value = gBuffer.texture[0];
  uniforms.uPosition.value = gBuffer.texture[1];
  // Bound by reference so the uniform always tracks the camera's current world matrix.
  uniforms.uViewMatrixInverse.value = camera.matrixWorld;

  // Restore whatever target the caller had bound instead of forcing the default framebuffer.
  const previousTarget = gl.getRenderTarget();
  gl.setRenderTarget(ssrRT);
  gl.render(ssrScene, camera);
  gl.setRenderTarget(previousTarget);
}
