#include "data\\shaders\\common.h"

// Dispatch dimensions for the tile-frustum pass, bound to constant buffer slot b4.
// The explicit pad fields round each uint3 up to a full 16-byte register; the
// layout is presumably kept to mirror a CPU-side struct byte-for-byte — TODO
// confirm against the host code before changing any field.
cbuffer DispatchParams : register( b4 )
{
    // Number of thread groups dispatched. Not read by main() in this file.
    uint3   num_thread_groups;
    uint pad1;

    // Number of threads (one per screen tile) in x/y. main() uses it as the
    // valid-thread bound and as the row stride into `frustums`.
    uint3   num_threads;
    uint pad2;
};

// System-value inputs delivered to the compute shader entry point.
// Only dispatch_thread_id is read by main(); the other members are declared but unused here.
struct ComputeShaderInput
{
    uint3 group_id            : SV_GroupID;           // index of this thread group within the dispatch
    uint3 group_thread_id     : SV_GroupThreadID;     // index of this thread within its group
    uint3 dispatch_thread_id  : SV_DispatchThreadID;  // global thread index across the whole dispatch
    uint group_index          : SV_GroupIndex;        // flattened form of group_thread_id
};

// Output buffer: one Frustum per tile. main() fills planes[0..3] and writes it
// row-major at index x + y * num_threads.x.
RWStructuredBuffer<Frustum> frustums : register( u0 );

// Builds the view-space side planes of one screen tile per thread.
//
// Thread (x, y) covers the TILE_SIZE x TILE_SIZE pixel tile whose top-left pixel
// is (x, y) * TILE_SIZE. The tile's four corners are converted to view space with
// ScreenToView (from common.h). The left/right/top/bottom planes are then built
// through the eye (view-space origin) and pairs of adjacent corners with
// ComputePlane (from common.h). No near/far planes are produced here. The result
// is written to frustums[x + y * num_threads.x].
[numthreads(TILE_SIZE,TILE_SIZE,1)]
void main(ComputeShaderInput In)
{
  const uint2 tile_id = In.dispatch_thread_id.xy;

  // The dispatch is made of whole thread groups and may therefore contain more
  // threads than tiles; those extra threads must not write (they would wrap into
  // the next row or run past the end of `frustums`).
  if (tile_id.x >= num_threads.x || tile_id.y >= num_threads.y)
    return;

  // Camera position in view space.
  const float3 eye_pos = float3(0.0f, 0.0f, 0.0f);

  // Corner offsets in tiles: 0 = top left, 1 = top right, 2 = bottom left,
  // 3 = bottom right (assuming screen +y points down — confirm against ScreenToView).
  const uint2 corner_offsets[4] = { uint2(0, 0), uint2(1, 0), uint2(0, 1), uint2(1, 1) };

  float3 view_space[4];
  [unroll]
  for (uint i = 0; i < 4; ++i)
  {
    // Convert to float before scaling so the division by g_screen_size is
    // floating-point for every corner, whatever vector type g_screen_size has.
    const float2 corner_px = float2(tile_id + corner_offsets[i]) * TILE_SIZE;
    // Normalized screen position (pixels / g_screen_size) with z = 1, w = 1.
    // Every plane below passes through eye_pos, so only each corner's direction
    // from the eye affects the planes.
    const float4 screen_pos = float4(corner_px / g_screen_size, 1.0f, 1.0f);
    view_space[i] = ScreenToView(screen_pos).xyz;
  }

  // Each side plane passes through the eye and two adjacent corners. The point
  // order passed to ComputePlane sets the normal's orientation; presumably it is
  // chosen so all four normals face the same way (e.g. into the frustum) —
  // confirm against ComputePlane in common.h.
  Frustum frustum;

  // left
  frustum.planes[0] = ComputePlane( eye_pos, view_space[2], view_space[0] );
  // right
  frustum.planes[1] = ComputePlane( eye_pos, view_space[1], view_space[3] );
  // top
  frustum.planes[2] = ComputePlane( eye_pos, view_space[0], view_space[1] );
  // bottom
  frustum.planes[3] = ComputePlane( eye_pos, view_space[3], view_space[2] );

  // Row-major tile index; unsigned to match the uint operands and buffer indexing.
  const uint index = tile_id.x + tile_id.y * num_threads.x;
  frustums[index] = frustum;
}