#import bevy_sprite::mesh2d_vertex_output::VertexOutput

// Shader inputs bound in bind group 2, one f32 uniform per binding.
// `time`: drives stripe scrolling and shape rotation below.
// NOTE(review): presumably elapsed seconds — confirm units on the Rust side.
@group(2) @binding(0) var<uniform> time: f32;
// `beat`: musical beat position. The timeline thresholds used throughout this
// shader (32, 48, 64, 80, 96, 112, 128, 160, 192, ...) compare against it, and its
// fractional part (`beat % 1.0`) drives the per-beat pulses.
@group(2) @binding(1) var<uniform> beat: f32;
// `percentage`: progress value used by percentage_brightness to fade in below
// 0.1 and fade out above 0.9. NOTE(review): presumably normalized to [0, 1] — confirm.
@group(2) @binding(2) var<uniform> percentage: f32;

// π, computed as 180 degrees converted to radians.
const pi = radians(180.0);

// Sine wave with a period of 1.0 in `x`, remapped from [-1, 1] into [0, 1].
fn scaledSin(x: f32) -> f32 {
    let wave = sin(x * (2 * pi));
    return (1 + wave) / 2;
}

// Colour of the pride-flag stripe with the given index:
// 0 red, 1 orange, 2 yellow, 3 green, 4 blue, 5 indigo, 6 violet.
// Any other index (negative or greater than 6) falls back to red.
fn pride_color_from_index(index : i32) -> vec3<f32> {
    const red = vec3(1.0, 0.0, 0.0);
    const orange = vec3(1.0, 0.5, 0.0);
    const yellow = vec3(1.0, 1.0, 0.0);
    const green = vec3(0.0, 1.0, 0.0);
    const blue = vec3(0.0, 0.0, 1.0);
    const indigo = vec3(0.34, 0.0, 0.66);
    const violet = vec3(0.54, 0.0, 0.75);

    switch index {
        case 1: { return orange; }
        case 2: { return yellow; }
        case 3: { return green; }
        case 4: { return blue; }
        case 5: { return indigo; }
        case 6: { return violet; }
        // Index 0 and every out-of-range index.
        default: { return red; }
    }
}

// Number of pride stripes to show at the current `beat`.
//
// Stripes are added in steps relative to beat 18 (`first`):
//   b <= 2 -> 1, then 2 / 3 / 4 stripes once b passes 2 / 3 / 4,
//   then 5 / 6 / 7 stripes once b passes 8 / 9 / 10 (where b = beat - 18).
// After beat 192 the count ramps linearly from 7 down towards 0 at beat 212.
//
// The ramp is clamped to a minimum of 1.0. pride_color_from_scalar divides by
// this value, so without the clamp beat == 212 divided by zero and later beats
// produced a negative count. Any count in (0, 1] already resolves every input
// to the red stripe, so the clamp does not change what is drawn.
fn pride_color_count() -> f32 {
    const first = 18.0;
    let b = beat - first;

    if beat > 192.0 {
        let remaining = 1.0 - ((beat - 192.0) / (212.0 - 192.0));
        return max(7.0 * remaining, 1.0);
    } else if b > 10 {
        return 7.0;
    } else if b > 9 {
        return 6.0;
    } else if b > 8 {
        return 5.0;
    } else if b > 4 {
        return 4.0;
    } else if b > 3 {
        return 3.0;
    } else if b > 2 {
        return 2.0;
    } else {
        return 1.0;
    }
}

// Maps a scalar to a pride stripe colour.
// The scalar is wrapped into one period with `% 1.0` and split into
// pride_color_count() equal-width bands. Past beat 64 the stripes scroll with
// `time`: at time / 2 until beat 80, then at time * 2.
// WGSL's float `%` keeps the sign of the dividend, so negative positions yield
// negative indices, which pride_color_from_index renders as red.
fn pride_color_from_scalar(input : f32) -> vec3<f32> {
    var position = input;
    if beat > 64 {
        position += select(time / 2, time * 2, beat > 80);
    }

    let stripe_width = 1.0 / pride_color_count();
    let wrapped = position % 1.0;
    let band = i32(floor(wrapped / stripe_width));
    return pride_color_from_index(band);
}

// Horizontal pride stripes: the colour depends only on the UV's y coordinate.
fn pride(input: vec2<f32>) -> vec3<f32> {
    let vertical = input.y;
    return pride_color_from_scalar(vertical);
}

// Background layer: the (possibly scrolling) pride stripes at this UV.
// The original also computed a moving direction vector and two scaledSin
// channels, but never used them (the only consumer was commented out).
// That dead code is dropped here.
fn bg(input: vec2<f32>) -> vec3<f32> {
    return pride(input);
}

// Blend weight used by cannabization: 0.0 up to beat 96, 1.0 afterwards.
fn cannaselector() -> f32 {
    return select(0.0, 1.0, beat > 96);
}

// Scalar mask for the cannabis shape at a UV coordinate.
//
// The UV is scaled by 3 and shifted so the shape is centred at UV (0.5, 5/6).
// A polar radius profile f(theta) is formed as the product of cosine terms with
// 8-, 24- and 200-fold frequency and a (1 + sin(theta)) term.
//
// The profile spins with `time` while 48 < beat < 64 or beat > 112, in reverse
// while 128 < beat < 144. The function returns either:
//   g1 - a soft inside mask whose edge width shrinks from ~1.02 to ~0.02 over each beat,
//   g2 - f - radius (positive inside the shape),
// blended by cannaselector().
fn cannabization(input: vec2<f32>) -> f32 {
    let p = input * 3 - vec2(1.5, 2.5);

    let spinning = (beat > 48 && beat < 64) || beat > 112;
    let reversed = beat > 128 && beat < 144;
    let spin = time * pi * 1.39;
    let rot = select(0.0, select(spin, -spin, reversed), spinning);

    // length() is never negative, so no abs() is needed.
    let radius = length(p) * 2;
    let theta = -atan2(p.y, p.x);

    let k8 = 1.0 + 0.9 * cos(8.0 * theta + rot);
    let k24 = 1.0 + 0.1 * cos(24.0 * theta + rot);
    let k200 = 0.9 + 0.1 * cos(200.0 * theta + rot);
    let k1 = 1.0 + sin(theta + rot);
    let f = k8 * k24 * k200 * k1;

    let pulse = 1.0 - (beat % 1.0);
    let soft = 1.0 - smoothstep(f, f + pulse + 0.02, radius);
    let signed_edge = f - radius;

    return mix(soft, signed_edge, cannaselector());
}

// Colours the cannabis mask.
// After beat 64 the mask value is fed through the pride stripe palette.
// Before that it is green, with a white flash that fades out over each beat.
fn cannacolor(cannaMask: f32) -> vec3<f32> {
    if beat > 64 {
        return pride_color_from_scalar(cannaMask);
    }

    let flash = cannaMask * (1.0 - (beat % 1.0));
    return vec3(flash, cannaMask, flash);
}

// Overall brightness envelope.
// It ramps linearly from 0 to 1 while percentage < 0.1, holds at 1, and
// ramps back down once percentage exceeds 0.9.
fn percentage_brightness() -> f32 {
    if percentage < 0.1 {
        return percentage / 0.1;
    }
    if percentage > 0.9 {
        let fade_out = (percentage - 0.9) / 0.1;
        return 1.0 - fade_out;
    }
    return 1.0;
}

// Blend weight between the background pride layer (0.0) and the cannabis
// layer (1.0) over the timeline:
//   beat <= 32        -> 0.0  pride stripes
//   32 < beat <= 64   -> 1.0  green pulsating cannabis
//   64 < beat <= 96   -> 0.0  scrolling pride
//   96 < beat <= 160  -> 1.0  full cannabis
//   160 < beat < 192  -> fractional part of beat (per-beat crossfade)
//   beat >= 192       -> 1.0  cannabis
fn selector() -> f32 {
    if beat > 192 {
        return 1.0;
    }
    if beat > 160 && beat < 192 {
        return beat % 1.0;
    }
    if beat > 96 {
        return 1.0;
    }
    if beat > 64 {
        return 0.0;
    }
    return select(0.0, 1.0, beat > 32);
}


// Fragment entry point: blends the pride background and the coloured cannabis
// layer by selector(), then scales the result by the brightness envelope.
@fragment
fn fragment(mesh: VertexOutput) -> @location(0) vec4<f32> {
    let uv = mesh.uv;
    // Named `background` so the local does not shadow the `bg` function.
    let background = bg(uv);
    let canna = cannacolor(cannabization(uv));
    let blended = mix(background, canna, selector());
    return percentage_brightness() * vec4(blended, 1.0);
}
