#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"

// Colour texture gathered from inside the blur loop; its dimensions also
// define the texel size used for the sample offsets.
// NOTE(review): the name suggests a half-resolution copy of the scene - confirm with the host code.
Texture2D g_color_half : register(t0);
// Circle-of-confusion texture. Only the .rg channels are read; GetBlurSize
// uses .r when it is positive and .g otherwise.
// NOTE(review): presumably r/g hold the near/far CoC - confirm with the pass that writes it.
Texture2D g_coc : register(t1);
// Colour texture sampled once at the pixel's own uv; it seeds the accumulation
// and is returned unchanged for pixels whose CoC is (almost) zero.
Texture2D g_color : register(t2);

// Sampler used for every colour fetch.
// NOTE(review): filtering mode is set by the application; "linear" is taken from the name only.
SamplerState g_sam_linear : register(s3);
// Sampler used for every CoC fetch.
// NOTE(review): filtering mode is set by the application; "point" is taken from the name only.
SamplerState g_sam_point : register(s1);

// Angle (radians) added per blur-loop iteration; 2.39996323 is the golden angle,
// which spreads successive samples evenly around the centre pixel.
static const float GOLDEN_ANGLE = 2.39996323;
// The blur loop stops once the sampling radius reaches this value. The radius is
// multiplied by the texel size of g_color_half, so it is expressed in those texels.
// Also used as a factor in GetBlurSize.
static const float MAX_BLUR_SIZE = 5.0f;
// Starting radius of the spiral and numerator of the per-iteration radius step
// (radius += RAD_SCALE / radius). Smaller values give more iterations.
static const float RAD_SCALE = 0.2f;

// Extra multiplier applied to the blur size in GetBlurSize (currently neutral).
static const float max_radius = 1.0f;

// Converts a two-channel circle-of-confusion value into a blur size.
//   coc: .r is used when it is strictly positive, otherwise .g is used.
// Returns the selected channel scaled by max_radius * MAX_BLUR_SIZE * 5.
float GetBlurSize(float2 coc)
{
  // Pick the driving channel first so the scale expression is written only once.
  const float amount = (coc.r > 0.0f) ? coc.r : coc.g;

  return amount * max_radius * MAX_BLUR_SIZE * 5;
}

// Depth-of-field gather pass.
//
// For each pixel, walks a golden-angle spiral around pin.uv, fetching colour from
// g_color_half and the circle of confusion from g_coc at every tap. A tap is blended
// in according to how far its own blur size reaches compared with the current spiral
// radius; otherwise the running average is re-added so the normalisation stays valid.
//
//   pin: interpolated vertex output; only pin.uv is read here.
// Returns the blurred colour with alpha forced to 1. Pixels whose CoC length is
// below 0.0001 return the g_color sample unchanged.
float4 main(VertexTOut pin) : SV_Target
{
  // Size of one texel of the texture gathered from, in uv units.
  float2 dim;
  g_color_half.GetDimensions(dim.x, dim.y);
  const float2 pixel_size = 1.0f.xx / dim;

  float3 color = g_color.Sample(g_sam_linear, pin.uv).rgb;
  const float2 coc = g_coc.Sample(g_sam_point, pin.uv).rg;

  // Nothing to blur for this pixel: pass the centre colour straight through.
  if (length(coc) < 0.0001f)
    return float4(color, 1.0f);

  // Every tap (and the centre sample already stored in color) has unit weight.
  const float w = 1.0f;
  float total = w;
  float radius = RAD_SCALE;
  for (float ang = 0.0f; radius < MAX_BLUR_SIZE; ang += GOLDEN_ANGLE)
  {
    const float2 tc = pin.uv + float2(cos(ang), sin(ang)) * pixel_size * radius;

    // Named sample_color because `sample` is an HLSL keyword (interpolation modifier).
    const float3 sample_color = g_color_half.Sample(g_sam_linear, tc).rgb;

    const float2 sample_coc = g_coc.Sample(g_sam_point, tc).rg;
    const float sample_size = GetBlurSize(sample_coc);

    // m is 0 when the tap's blur size is below radius - 0.5 and 1 above radius + 0.5,
    // with a smooth ramp in between. At m == 0 the current average is added instead
    // of the tap, which leaves the final average unchanged.
    const float m = smoothstep(radius - 0.5f, radius + 0.5f, sample_size);
    color += lerp(color / total, sample_color, m);

    total += w;
    // The step shrinks as the radius grows, so outer rings get more taps.
    radius += RAD_SCALE / radius;
  }

  return float4(color / total, 1.0f);
}