import package::shaders::pbr::pbr_material;
import package::shaders::pbr::pbr_material_lightmap;
import package::shaders::math::calculate_camera_pos_worldspace;
// const PI: f32 = 3.14159265359;

// Per-vertex attributes consumed by vs_main.
struct VertexInput {
    // Model-space position; vs_main multiplies it by pow(2.0, context.scale) before
    // transforming with g_world_from_model.
    @location(0) a_position: vec3<f32>,
    // Model-space normal; transformed to world space as a direction (w = 0).
    @location(1) a_normal: vec3<f32>,
    // Model-space tangent; transformed to world space as a direction (w = 0).
    @location(2) a_tangent: vec3<f32>,
    // Texture coordinates, passed through unchanged.
    @location(3) a_uv: vec2<f32>,
};

// Values produced by vs_main and interpolated into the fragment entry points.
struct VertexOutput {
    // Clip-space position (g_projection_from_world * world-space position).
    @builtin(position) position: vec4<f32>,
    @location(0) v_uv: vec2<f32>,
    // World-space normal and tangent; vs_main does not renormalize them.
    @location(1) v_normal_worldspace: vec3<f32>,
    @location(2) v_tangent_worldspace: vec3<f32>,
    @location(3) v_pos_worldspace: vec3<f32>,
    // Camera position in world space, derived per vertex from g_camera_from_world.
    @location(4) v_camera_pos_worldspace: vec3<f32>,
    // NOTE(review): vs_main in this file never assigns this field (so it stays
    // zero-initialized) and fs_main never reads it; confirm whether it is still needed.
    @location(5) v_material_adjustment: vec3<f32>,
};

// Vertex-stage uniform block (bound as `context`, group 0 binding 0).
// The host-side struct must match WGSL uniform layout rules. Byte offsets are shown
// per field; vec3 members are 16-byte aligned, so the struct has implicit padding.
// NOTE(review): only g_projection_from_world, g_camera_from_world, g_world_from_model
// and scale are read by the code visible here; the other fields are presumably shared
// with other shaders using this layout. Confirm before removing any of them.
struct VsUniforms {
    g_projection_from_world: mat4x4<f32>,   // offset 0
    g_projection_from_model: mat4x4<f32>,   // offset 64
    g_camera_from_model: mat4x4<f32>,       // offset 128
    g_camera_from_world: mat4x4<f32>,       // offset 192
    g_world_from_model: mat4x4<f32>,        // offset 256
    g_light_dir_worldspace_norm: vec3<f32>, // offset 320 (12 bytes)
    g_app_time: f32,                        // offset 332 (packs into the vec3's tail)
    g_simulation_frame_ratio: f32,          // offset 336
    // log2 of the model scale: vs_main scales positions by pow(2.0, scale).
    scale: f32,                             // offset 340
    // Starts at offset 352 (8 bytes of implicit padding before it); struct size is 368.
    instance_move: vec3<f32>,
};

// Particle record. NOTE(review): not referenced by any code visible in this file.
// Under WGSL host-shareable layout rules each vec3 is 16-byte aligned: offsets are
// 0 / 16 / 32 and the struct size is 48 bytes.
struct Particle {
    position: vec3<f32>,
    velocity: vec3<f32>,
    upvector: vec3<f32>,
} 

// Vertex-stage uniforms (transforms, scale, timing); read by vs_main.
@group(0) @binding(0) var<uniform> context: VsUniforms;


// Fragment-stage uniform block (bound as `u`, group 1 binding 0).
// The host-side struct must match WGSL uniform layout rules. Byte offsets are shown;
// vec3 and mat3x3 members are 16-byte aligned, so several implicit padding gaps exist.
// Total size is 384 bytes.
struct FsUniforms {
    // Maps world space into the light's clip space; used by sample_shadow_map.
    g_light_projection_from_world: mat4x4<f32>, // offset 0
    g_camera_from_world: mat4x4<f32>,           // offset 64
    g_projection_from_camera: mat4x4<f32>,      // offset 128
    g_chart_time: f32,                          // offset 192
    g_app_time: f32,                            // offset 196 (8 bytes padding follow)
    g_light_dir_camspace_norm: vec3<f32>,       // offset 208
    g_light_dir_worldspace_norm: vec3<f32>,     // offset 224 (4 bytes padding follow)
    // mat3x3 occupies 48 bytes: each column is a vec3 padded to 16 bytes, so the host
    // must upload it as three 16-byte columns.
    g_lightspace_from_world: mat3x3<f32>,       // offset 240
    light_color: vec4<f32>,                     // offset 288 (only .rgb is read)
    // Signed adjustments applied via adjust() when a roughness/metallic map is bound,
    // otherwise used directly as the material value.
    roughness: f32,                             // offset 304
    metallic: f32,                              // offset 308
    ambient: f32,                               // offset 312 (splatted to vec3)
    normal_strength: f32,                       // offset 316
    // Depth offset for shadow comparisons, scaled by -0.001 in sample_shadow_map.
    shadow_bias: f32,                           // offset 320 (12 bytes padding follow)
    // Base color, or a tint multiplied with base_color_map when that map is bound.
    color: vec3<f32>,                           // offset 336
    emissive: f32,                              // offset 348 (emissive_map multiplier)
    // 0 disables the ambient-occlusion map, 1 applies it fully (mix factor).
    ao_adjust: f32,                             // offset 352 (12 bytes padding follow)
    _pad: vec4f,                                // offset 368
};

// Fragment-stage resources (bind group 1).
@group(1) @binding(0) var<uniform> u: FsUniforms;
@group(1) @binding(1) var envmap: texture_2d<f32>;
// The @if below attaches only to the `shadow` declaration that immediately follows it;
// the remaining bindings are declared in every variant.
@if(!ENTRY_POINT_FS_MAIN_NOOP) 
@group(1) @binding(2) var shadow: texture_depth_2d;
@group(1) @binding(3) var base_color_map: texture_2d<f32>;
@group(1) @binding(4) var roughness_map: texture_2d<f32>;
@group(1) @binding(5) var metallic_map: texture_2d<f32>;
@group(1) @binding(6) var normal_map: texture_2d<f32>;
@group(1) @binding(7) var brdf_lut: texture_2d<f32>;
@group(1) @binding(8) var emissive_map: texture_2d<f32>;
@group(1) @binding(9) var ambient_occlusion_map: texture_2d<f32>;
@group(1) @binding(10) var light_map: texture_2d<f32>;

// Samplers: environment lookups, depth-comparison shadow lookups, and a wrapping
// sampler (going by its name) for the material textures.
@group(1) @binding(11) var sampler_envmap: sampler;
@group(1) @binding(12) var sampler_shadow: sampler_comparison;
@group(1) @binding(13) var sampler_repeat: sampler;

// Returns the depth-comparison result of the shadow map at `world_pos`.
// The result is in [0, 1]; the comparison function is configured on sampler_shadow.
//   world_pos: fragment position in world space.
//   shadow:    light-space depth texture.
fn sample_shadow_map(world_pos: vec3<f32>, shadow: texture_depth_2d) -> f32 {
    let clip = u.g_light_projection_from_world * vec4<f32>(world_pos, 1.0);
    // Perspective divide into NDC. For orthographic light projections w is exactly 1,
    // so this is a no-op; for perspective (e.g. spot) projections it is required.
    var lightspace_pos = clip.xyz / clip.w;
    // NDC xy in [-1, 1] -> texture uv in [0, 1]. y is flipped because the texture
    // v axis points down. Depth is offset by -shadow_bias * 0.001, presumably to
    // counter self-shadowing artifacts.
    lightspace_pos = lightspace_pos * vec3f(0.5, -0.5, 1) + vec3f(0.5, 0.5, u.shadow_bias * -0.001);
    return textureSampleCompare(shadow, sampler_shadow, lightspace_pos.xy, lightspace_pos.z);
}

// Remaps `value` by a signed `factor`:
//   factor <  0 -> scales value toward 0: value * (1 + factor); factor = -1 yields 0.
//   factor >= 0 -> blends value toward 1: factor + value * (1 - factor); factor = 1 yields 1.
// factor = 0 returns value unchanged.
fn adjust(value: f32, factor: f32) -> f32 {
    let toward_zero = value * (1.0 + factor);
    let toward_one = factor + value * (1.0 - factor);
    return select(toward_one, toward_zero, factor < 0.0);
}

// Samples `map` at `uv` with sampler_repeat and decodes the RGB channels from
// gamma-encoded (sRGB-like) values to linear, using the gamma-2.2 approximation of the
// sRGB transfer curve.
// The exponent must be 2.2 to decode. 1/2.2 performs the opposite
// (linear -> sRGB) conversion and brightens mid-tones.
fn sample_srgb_as_linear(map: texture_2d<f32>, uv: vec2<f32>) -> vec3<f32> {
    let encoded = textureSample(map, sampler_repeat, uv).rgb;
    return pow(encoded, vec3<f32>(2.2));
}

// Main PBR fragment entry point; compiled out when ENTRY_POINT_FS_MAIN_NOOP is set.
// Steps:
//   1. Resolve roughness, metallic and base color, either from the bound textures
//      (combined with the matching uniform) or from the uniforms alone.
//   2. Shade with pbr_material, or pbr_material_lightmap when a light map is bound.
//   3. Optionally add emissive and apply ambient occlusion.
// Output alpha is always 1.0.
@if(!ENTRY_POINT_FS_MAIN_NOOP) 
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> { 
    // Shadow comparison result; scales the direct light color passed to the BRDF below.
    let lightness = sample_shadow_map(in.v_pos_worldspace, shadow);

    // Each material value is declared twice under mutually exclusive @if conditions,
    // so exactly one declaration survives conditional compilation.
    // With a map bound, the uniform acts as a signed adjustment via adjust().
    @if(TEXTURE_BOUND_TO_ROUGHNESS_MAP)
    let roughness = adjust(sample_srgb_as_linear(roughness_map, in.v_uv).r, u.roughness);

    @if(!TEXTURE_BOUND_TO_ROUGHNESS_MAP)
    let roughness = u.roughness;

    @if(TEXTURE_BOUND_TO_METALLIC_MAP)
    let metallic = adjust(sample_srgb_as_linear(metallic_map, in.v_uv).r, u.metallic);

    @if(!TEXTURE_BOUND_TO_METALLIC_MAP)
    let metallic = u.metallic;

    // With a base color map bound, u.color tints the sampled texel.
    // NOTE(review): base_color is never reassigned in this function; `let` would do.
    // The texel is used without gamma decoding, which is only correct if the texture
    // uses an *-srgb format or stores linear data. Confirm against the texture setup.
    @if(TEXTURE_BOUND_TO_BASE_COLOR_MAP)
    var base_color = textureSample(base_color_map, sampler_repeat, in.v_uv).rgb * u.color;

    @if(!TEXTURE_BOUND_TO_BASE_COLOR_MAP)
    var base_color = u.color;

    // Shading without a light map.
    @if(!TEXTURE_BOUND_TO_LIGHT_MAP)
    var color = pbr_material(in.v_uv, in.v_pos_worldspace, in.v_normal_worldspace, 
        in.v_tangent_worldspace, 
        in.v_camera_pos_worldspace, u.g_light_dir_worldspace_norm,
        u.normal_strength, u.light_color.rgb * lightness, vec3f(u.ambient), 
        base_color, roughness, metallic, normal_map, 
        envmap, brdf_lut, 
        sampler_repeat, sampler_envmap);  

    // Shading with a light map; additionally receives light_map and g_lightspace_from_world.
    @if(TEXTURE_BOUND_TO_LIGHT_MAP)
    var color = pbr_material_lightmap(in.v_uv, in.v_pos_worldspace, in.v_normal_worldspace, 
        in.v_tangent_worldspace, 
        in.v_camera_pos_worldspace, u.g_light_dir_worldspace_norm,
        u.normal_strength, u.light_color.rgb * lightness, vec3f(u.ambient), 
        base_color, roughness, metallic, normal_map, 
        envmap, brdf_lut, light_map,
        sampler_repeat, sampler_envmap, u.g_lightspace_from_world); 

    // Additive emission scaled by u.emissive.
    @if(TEXTURE_BOUND_TO_EMISSIVE_MAP) 
    {
        color = color + textureSample(emissive_map, sampler_repeat, in.v_uv).rgb * u.emissive;
    }

    // AO darkening. u.ao_adjust = 0 leaves color unchanged; 1 multiplies fully by the AO texel.
    @if(TEXTURE_BOUND_TO_AMBIENT_OCCLUSION_MAP)
    {
        let ao = textureSample(ambient_occlusion_map, sampler_repeat, in.v_uv).r;
        color = color * mix(1.0, ao, u.ao_adjust);
    }

    return vec4<f32>(color, 1.0);  
}

// Fragment entry point with no color outputs and no work.
// NOTE(review): presumably selected for passes where only rasterization/depth matters
// (e.g. a depth-only or shadow pass); confirm against the pipeline setup.
@fragment
fn fs_main_noop(in: VertexOutput) {
    return;
}

// Vertex entry point: scales the model, transforms it to world and clip space, and
// forwards world-space geometry data to the fragment stage.
// NOTE(review): instance_index (and context.instance_move) are not used here.
// Normals and tangents are transformed by g_world_from_model directly. That is correct
// only for rotations, translations and uniform scale; non-uniform scale would need the
// inverse-transpose for normals.
@vertex
fn vs_main(input: VertexInput, @builtin(instance_index) instance_index: u32) -> VertexOutput {
    // context.scale is a log2 exponent: 0 -> unit size, 1 -> double, -1 -> half.
    let size_factor = pow(2.0, context.scale);
    let world_from_model = context.g_world_from_model;

    let world_pos = (world_from_model * vec4<f32>(input.a_position * size_factor, 1.0)).xyz;
    let world_normal = (world_from_model * vec4<f32>(input.a_normal, 0.0)).xyz;
    let world_tangent = (world_from_model * vec4<f32>(input.a_tangent, 0.0)).xyz;
    let camera_pos = calculate_camera_pos_worldspace(context.g_camera_from_world);

    return VertexOutput(
        context.g_projection_from_world * vec4<f32>(world_pos, 1.0), // position
        input.a_uv,                                                  // v_uv
        world_normal,                                                // v_normal_worldspace
        world_tangent,                                               // v_tangent_worldspace
        world_pos,                                                   // v_pos_worldspace
        camera_pos,                                                  // v_camera_pos_worldspace
        vec3<f32>(0.0),                                              // v_material_adjustment (unused)
    );
}
