#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"
#include "data\\shaders\\lights.h"
#include "data\\shaders\\pcss.h"

// Per-instance light records; main() reads one entry per draw instance via pin.instance_id.
StructuredBuffer<CullableLight> g_cullable_lights : register(t1, space1);

// Depth buffer sampled in main() (x channel) to reconstruct the position of the
// opaque surface behind each pixel via ScreenToView.
Texture2D g_depth : register(t0);
// Array of 10 shadow maps. Not referenced by main() in this file.
// NOTE(review): confirm whether this binding is still needed here (shared root signature / PCSS path?).
Texture2D g_shadow_maps[10] : register(t0, space2);

// Sampler used for the depth fetch in main(). Filtering mode is configured on the host
// side; "linear" is inferred from the name only — TODO confirm.
SamplerState g_sam_linear : register(s3);

// Mie scattering approximated with the Henyey-Greenstein phase function.
// Reference: https://www.alexandre-pestana.com/volumetric-lights/
//
// lightDotView: presumably the cosine of the angle between light and view
//               directions (i.e. a dot product of normalized vectors) — confirm at call sites.
// scattering:   the phase-function parameter g.
// Returns (1 - g^2) / (4 * PI * (1 + g^2 - 2 * g * lightDotView)^1.5).
float ComputeScattering(float lightDotView, float scattering)
{
  float scatter_sq = scattering * scattering;
  float numer = 1.0f - scatter_sq;
  float denom_base = 1.0f + scatter_sq - (2.0f * scattering) * lightDotView;
  return numer / (4.0f * PI * pow(denom_base, 1.5f));
}

// Volumetric point-light pixel shader, executed once per light instance.
// Ray-marches from the eye (origin of the ScreenToView space) toward the nearer of
// (a) the opaque surface reconstructed from g_depth and (b) the light-volume surface
// interpolated in pin.pos_v, accumulating an analytic in-scattering term per sample.
//
// pin.instance_id: index into g_cullable_lights.
// pin.clip_pos:    clip-space position, used to rebuild the screen UV.
// pin.pos_v:       light-volume surface position, compared against the reconstructed
//                  scene position (assumed same space, larger z = farther — TODO confirm).
// Returns: rgb = accumulated scattering * light.color * light.intensity;
//          a   = z of the reconstructed scene position, taken BEFORE clamping to the
//                light volume.
float4 main(VertexOutPIOut pin) : SV_Target
{
  CullableLight light = g_cullable_lights[pin.instance_id];
  float2 half_size = g_screen_size * 0.5f;

  // Clip space -> NDC -> [0,1] texture coordinates (y flipped).
  float2 uv = pin.clip_pos.xy / pin.clip_pos.w * float2(0.5f, -0.5f) + 0.5f;
  // NOTE(review): depth goes through g_sam_linear; if this pass is lower resolution than
  // g_depth, filtering can blend foreground/background depths at silhouettes.
  float depth = g_depth.Sample(g_sam_linear, uv).x;
  float3 scene_pos = ScreenToView(float4(uv, depth, 1.0f)).xyz;

  // Written to the alpha channel; captured before the clamp below.
  float scene_z = scene_pos.z;

  // Stop marching at whichever is nearer: the opaque surface or the light-volume surface.
  float3 ray_end = scene_pos;
  if (scene_pos.z > pin.pos_v.z)
    ray_end = pin.pos_v;

  float3 eye_pos = float3(0, 0, 0);
  float3 ray_start = eye_pos;
  float3 view_dir = normalize(ray_end);

  const uint sample_count = 16;
  float step_size = distance(ray_start, ray_end) / sample_count;
  // Named ray_step (not "step") so it does not shadow the step() intrinsic.
  float3 ray_step = view_dir * step_size;

  float density = light.scatt_density;
  float multiplier = light.scatt_multiplier;
  float falloff = light.scatt_falloff;

  // Jitter the first sample along the ray to trade banding for noise.
  float3 sample_pos = ray_start + ray_step * Dither(uv * half_size);
  float3 accumulation = 0.0f.xxx;
  for (uint i = 0; i < sample_count; ++i)
  {
    float dist_to_light = distance(sample_pos, light.position);
    float dist_to_eye = distance(sample_pos, eye_pos);

    // Distances to the light and to the eye, normalized by the light radius
    // (clamped to [0,1]) and scaled by the falloff parameter.
    float light_dist_term = falloff * saturate(dist_to_light / light.radius);
    float eye_dist_term = falloff * saturate(dist_to_eye / light.radius);

    // d^2 * 4 * PI is never negative, so no abs() is needed; the epsilon prevents
    // division by zero at the light centre.
    float sphere_denom = light_dist_term * light_dist_term * 4.0f * PI + 0.00001f;

    float3 in_scatter = exp(-light_dist_term * density) * multiplier / (sphere_denom);
    float3 li = in_scatter * density * 0.01f;

    float3 contribution = point_light_falloff(dist_to_light, light.radius);
    contribution *= li * exp(-eye_dist_term * density) * step_size;
    // NOTE(review): if re-enabled, normalize both vectors first — a dot product of a
    // light-relative offset and a position is not a cosine.
    //contribution *= ComputeScattering(saturate(dot(light.position - sample_pos, sample_pos)), 0.86);

    accumulation += contribution;
    sample_pos += ray_step;
  }

  // NOTE(review): samples are already weighted by step_size, so this extra division makes
  // brightness depend on sample_count; left unchanged since light parameters are presumably
  // tuned against it.
  accumulation /= sample_count;

  return float4(max(0.0f, accumulation) * light.color * light.intensity, scene_z);
}