#!/usr/bin/env python3
import ctypes
import itertools
import math
import optparse
import os
import random
import re
import sys

from PIL import Image
import pygame
from pygame.locals import *

import bassmusic
import ezgl
from ezgl import gl
import rocket

# Internal ("virtual") rendering resolution. Map modes render off-screen at
# this size, and the final pass is fitted into the window with preserved
# aspect ratio (see update_viewport in the main section).
VirtualSize = (1280, 720)
VirtualAspect = VirtualSize[0] / VirtualSize[1]
Verbosity = 0  # overridden from -v/--verbose in the main section

AssetDir = "stripes_assets"
# Music used when no file/directory is given on the command line.
# NOTE: the first assignment is immediately overwritten; presumably it is kept
# as an easy way to disable the default track (remove the second line).
DefaultMusic = None
DefaultMusic = os.path.join(AssetDir, "rahmschwein.mp3")
DefaultBPM = 125
DefaultRBP = 8     # default rows per beat (-r/--rpb)
QuitAtRow = 2200   # in player mode the main loop exits once sync.row reaches this

def fract(x):
    """Return x minus floor(x), like GLSL fract()."""
    whole = math.floor(x)
    return x - whole


def mix(a, b, x):
    """Linearly interpolate from a (x == 0) to b (x == 1), like GLSL mix()."""
    return a + (b - a) * x


def vmix(va, vb, x):
    """Component-wise mix() of two sequences; the result is a tuple truncated to the shorter input."""
    return tuple(mix(a, b, x) for a, b in zip(va, vb))

################################################################################
# MARK: TexturedRect

class TexturedRect(ezgl.Shader):
    """Shader that draws a textured axis-aligned rectangle from a unit-square VBO.

    screenArea = (x0, y0, x1 - x0, y1 - y0) places the quad in normalized
    device coordinates. texArea = (s0, t0, s1 - s0, t1 - t0) selects the
    texture sub-rectangle. prepare_draw() binds the texture to unit 0.
    """
    vs = """
        attribute highp vec2 pos;
        uniform highp vec4 screenArea;
        uniform highp vec4 texArea;
        varying mediump vec2 texCoord;
        void main() {
            gl_Position = vec4(screenArea.xy + screenArea.zw * pos, 0.0, 1.0);
            texCoord = texArea.xy + texArea.zw * pos;
        }
    """
    fs = """
        uniform lowp sampler2D tex;
        varying mediump vec2 texCoord;
        void main() {
            gl_FragColor = texture2D(tex, texCoord);
        }
    """
    attributes = { 'pos': 0 }
    uniforms = [ 'screenArea', 'texArea' ]

    def __init__(self):
        ezgl.Shader.__init__(self)
        # The vertex buffer is created lazily by the first use_vbo() call.
        self.vbo = None

    def use_vbo(self):
        """Set up vertex attribute 0 from the unit-square buffer, creating the buffer on first use."""
        if self.vbo is not None:
            gl.BindBuffer(gl.ARRAY_BUFFER, self.vbo)
        else:
            # Corners (0,0) (1,0) (0,1) (1,1) in triangle-strip order.
            # NOTE(review): relies on create_static_buffer leaving the new buffer bound.
            self.vbo = gl.create_static_buffer(gl.ARRAY_BUFFER, type=gl.SHORT, data=[0,0,1,0,0,1,1,1])
        gl.VertexAttribPointer(0, 2, gl.SHORT, gl.FALSE, 0, 0)
        gl.set_enabled_attribs(0)

    def prepare_draw(self, tex=None):
        """Bind `tex` to unit 0 (when given), set up the VBO, activate the program and return self."""
        if tex is not None:
            gl.set_texture(tex=tex, tmu=0)
        self.use_vbo()
        self.use()
        return self

    def draw(self, tex=None, x0=-1, y0=1, x1=1, y1=-1, s0=0, t0=0, s1=1, t1=1, prepare=True):
        """Draw one rectangle from (x0, y0) to (x1, y1) in NDC, mapped to (s0, t0)..(s1, t1).

        The defaults cover the whole viewport with t=0 at the top (y=+1).
        With prepare=False the caller must already have called prepare_draw().
        Returns self.
        """
        if prepare:
            self.prepare_draw(tex)
        screen_area = (x0, y0, x1 - x0, y1 - y0)
        tex_area = (s0, t0, s1 - s0, t1 - t0)
        gl.Uniform4f(self.screenArea, *screen_area)
        gl.Uniform4f(self.texArea, *tex_area)
        gl.DrawArrays(gl.TRIANGLE_STRIP, 0, 4)
        return self

################################################################################
# MARK: utils

def GenerateCircleTexture(size):
    """Create a size x size LUMINANCE texture containing a soft disc.

    For each texel, with x and y the squared (doubled, odd-centered)
    coordinates, the value is 255 * (norm * max(0, size^2 - x - y))^2.
    norm = 1 / (size^2 - 2) makes the innermost texels reach 255. Texels
    outside the inscribed circle are 0. The texture uses LINEAR filtering.
    """
    squares = [v * v for v in range(1 - size, size, 2)]
    radius_sq = size * size
    norm = 1.0 / (radius_sq - 2)
    texels = bytearray()
    for ys in squares:          # rows
        for xs in squares:      # columns
            level = norm * max(0, radius_sq - xs - ys)
            texels.append(int(0.5 + 255.0 * level * level))
    tex = gl.make_texture(filter=gl.LINEAR)
    gl.TexImage2D(gl.TEXTURE_2D, 0, gl.LUMINANCE, size, size, 0, gl.LUMINANCE, gl.UNSIGNED_BYTE, bytes(texels))
    return tex

class Sinusoid(object):
    """Sine wave with random frequency and phase, oscillating between min_val and max_val.

    The frequency is drawn uniformly from [min_freq, max_freq] and the phase
    from 0..2*pi, both from the module-level `random` generator, so the
    result follows random.seed().
    """
    def __init__(self, min_freq=0.0, max_freq=1.0, min_val=0.0, max_val=1.0):
        # The draw order (frequency, then phase) is what makes seeded runs reproducible.
        self.freq = random.uniform(min_freq, max_freq)
        self.phase = random.uniform(0, 2 * math.pi)
        self.amp = 0.5 * (max_val - min_val)
        self.offset = 0.5 * (max_val + min_val)

    def __call__(self, t):
        """Evaluate the wave at time t."""
        angle = self.freq * t + self.phase
        return self.amp * math.sin(angle) + self.offset

class Bob(object):
    """A soft circular blob that drifts on two independent random sinusoids.

    `size` is the blob width in NDC units. Its height is scaled by
    VirtualAspect, so the blob is square on screen.
    """
    def __init__(self, size):
        self.size = size
        # x, y oscillate in [0, 1]; scaling by (2 - size) and shifting by -1
        # keeps the horizontal extent inside [-1, 1].
        self.scale = 2.0 - size
        top_freq = 0.125 - 0.0625 * size   # larger blobs move more slowly
        self.x = Sinusoid(0.03125, top_freq)
        self.y = Sinusoid(0.03125, top_freq)

    def draw(self, t):
        """Draw the blob at time t with the currently bound texture."""
        left = self.x(t) * self.scale - 1.0
        bottom = self.y(t) * self.scale - 1.0
        right = left + self.size
        top = bottom + self.size * VirtualAspect
        TexturedRect.get_instance().draw(None, left, bottom, right, top)

################################################################################
# MARK: map modes

class MapMode(object):
    """Base class for map modes; draw() renders into the current target and does nothing by default."""

    def draw(self):
        return None

class MapMode_Blank(MapMode):
    "empty map"
    # Mode 0 is also the fallback the main loop uses for unknown mapMode values.
    mode_id = 0

class MapMode_Blobs(MapMode):
    "randomly moving blobs in the background"
    mode_id = 6106
    def __init__(self, count=8, seed=0x13375EED):
        # Reseeding the module-level generator makes the blob layout
        # reproducible (and also affects any later users of `random`).
        random.seed(seed)
        self.circleTex = GenerateCircleTexture(64)
        self.bobs = []
        for _ in range(count):
            self.bobs.append(Bob(random.uniform(0.125, 1.0)))

    def draw(self):
        # Additive "screen"-style blending: ONE_MINUS_DST_COLOR * src + dst.
        gl.Enable(gl.BLEND)
        gl.BlendFunc(gl.ONE_MINUS_DST_COLOR, gl.ONE)
        gl.set_texture(tex=self.circleTex, tmu=0)
        # The sync row serves as the animation time; it is read once per bob.
        for bob in self.bobs:
            bob.draw(sync.row)
        gl.Disable(gl.BLEND)

class MapMode_Logo(MapMode):
    "blurred TRBL logo"
    mode_id = 1337
    def __init__(self):
        self.logo = gl.make_texture(filter=gl.LINEAR)
        # Open the image in a `with` block so the file handle is released
        # deterministically once the pixels have been uploaded.
        with Image.open(os.path.join(AssetDir, "trbl_logo_blur.jpg")) as img:
            gl.load_texture(gl.TEXTURE_2D, img)

    def draw(self):
        # Four copies radiate from the screen center, mirrored by flipping the
        # sign of x1/y1 (texture coordinates stay 0..1 for every copy). Each
        # copy spans 1/VirtualAspect x 1.0 in NDC, i.e. a square on screen.
        # Presumably the asset holds one quadrant of the logo.
        r = TexturedRect.get_instance().prepare_draw(self.logo)
        x = 1.0 / VirtualAspect
        for y1 in (-1.0, +1.0):
            for x1 in (-x, +x):
                r.draw(x0=0.0, y0=0.0, x1=x1, y1=y1)

class MapMode_ImageFile(MapMode):
    "generic image"
    # Not auto-registered (mode_id None); instances are created per numbered
    # asset file by load_image_assets in the main section.
    mode_id = None
    def __init__(self, filename):
        # _name is shown instead of the class docstring in the -v mode listing.
        self._name = os.path.basename(filename)
        self.tex = gl.make_texture(filter=gl.LINEAR)
        # `with` closes the file once the texture is uploaded and the size is read.
        with Image.open(filename) as img:
            gl.load_texture(gl.TEXTURE_2D, img)
            self.aspect = img.size[0] / img.size[1]

    def draw(self):
        # Fit the image inside the virtual screen with preserved aspect ratio.
        ex = min(self.aspect / VirtualAspect, 1.0)
        ey = min(VirtualAspect / self.aspect, 1.0)
        # NOTE(review): y0=-ey puts t=0 at the bottom of the off-screen target,
        # unlike TexturedRect.draw's default. Presumably this compensates the
        # vertical flip when the stripe shader samples the render target;
        # confirm before changing.
        TexturedRect.get_instance().prepare_draw(self.tex).draw(x0=-ex, y0=-ey, x1=ex, y1=ey)

################################################################################
# MARK: color modes

class ColorMode(object):
    """Base class for color modes, which supply the RGBA bytes of the per-stripe color texture."""

    def get(self):
        """Return the RGBA byte string for this frame (static modes precompute self.data)."""
        return self.data

    @staticmethod
    def tobytes(data):
        """Convert numbers (or a sequence of number sequences) to clamped, rounded bytes.

        Nested lists/tuples are flattened one level. Each value is rounded
        with int(x + 0.5) and clamped to 0..255. Empty input yields b''.
        """
        # Guard on `data` first: indexing data[0] on an empty sequence raised IndexError.
        if data and isinstance(data[0], (list, tuple)):
            data = itertools.chain.from_iterable(data)
        return bytes(min(255, max(0, int(x + 0.5))) for x in data)

class ColorMode_Default(ColorMode):
    "nice fixed palette"
    # Mode 0 is also the fallback the main loop uses for unknown colorMode values.
    mode_id = 0
    def __init__(self, size=64):
        # Start with opaque black, then color every row symmetrically around
        # index 0. Negative k wraps to the end of the list, so rows k and
        # size-k get the same color.
        rows = [[0.0, 0.0, 0.0, 255.0] for _ in range(size)]
        for k in range(-size // 2, size // 2):
            dist = abs(k)
            rows[k][0:3] = [64 + 192 * fract(0.3 * dist),
                            64 + 192 * fract(0.5 * dist),
                            64 + 192 * fract(0.5 * dist)]
        self.data = self.tobytes(rows)

class ColorMode_fg1(ColorMode):
    "use fg1 directly"
    mode_id = 1
    # A single texel, so every stripe samples the fg1 color.
    def get(self):
        fg1_rgba = sync.fg1.v
        return self.tobytes(fg1_rgba)

class ColorMode_Alternating(ColorMode):
    "alternate between fg0 and fg1"
    mode_id = 2
    # Two texels; the stripe shader indexes the texture by stripe number,
    # so consecutive stripes alternate between fg0 and fg1.
    def get(self):
        first, second = sync.fg0.v, sync.fg1.v
        return b"".join((self.tobytes(first), self.tobytes(second)))

class ColorModeBase_BinaryMask(ColorMode):
    """Shared get(): one texel per mask bit (mode_id bits), chosen by the colorMask track.

    A clear bit gives fg0. A set bit gives fg0 mixed toward fg1 by colorMix.
    """
    def get(self):
        base = sync.fg0.v
        low = self.tobytes(base)
        high = self.tobytes(vmix(base, sync.fg1.v, sync.colorMix.value))
        mask = int(sync.colorMask.value + 0.5)
        texels = []
        for bit in range(self.mode_id):
            texels.append(high if mask & (1 << bit) else low)
        return b''.join(texels)
class ColorMode_BinaryMask8(ColorModeBase_BinaryMask):
    "use 8-bit binary mask to choose between fg0 and fg1"
    mode_id = 8
class ColorMode_BinaryMask16(ColorModeBase_BinaryMask):
    "use 16-bit binary mask to choose between fg0 and fg1"
    mode_id = 16

class ColorModeBase_DeciMask(ColorMode):
    """Shared logic: one texel per decimal digit of colorMask (mode_id digits, least significant first).

    A digit d mixes fg0 toward fg1 by d/9 * colorMix.
    """
    def __init__(self):
        # Place values 1, 10, 100, ... for the mode_id digits.
        self.div = [pow(10, digit) for digit in range(self.mode_id)]
    def get(self):
        step = sync.colorMix.value * 0.1111111111   # ~1/9 per digit step
        mask = int(sync.colorMask.value + 0.5)
        fg0, fg1 = sync.fg0.v, sync.fg1.v
        parts = []
        for place in self.div:
            digit = (mask // place) % 10
            parts.append(self.tobytes(vmix(fg0, fg1, digit * step)))
        return b''.join(parts)
class ColorMode_DeciMask3(ColorModeBase_DeciMask):
    "use 3-digit decimal mask to mix between fg0 and fg1"
    mode_id = 3
class ColorMode_DeciMask4(ColorModeBase_DeciMask):
    "use 4-digit decimal mask to mix between fg0 and fg1"
    mode_id = 4
class ColorMode_DeciMask5(ColorModeBase_DeciMask):
    "use 5-digit decimal mask to mix between fg0 and fg1"
    mode_id = 5
class ColorMode_DeciMask6(ColorModeBase_DeciMask):
    "use 6-digit decimal mask to mix between fg0 and fg1"
    mode_id = 6
class ColorMode_DeciMask7(ColorModeBase_DeciMask):
    "use 7-digit decimal mask to mix between fg0 and fg1"
    mode_id = 7

################################################################################
# MARK: stripe shader

class StripeShader(ezgl.Shader):
    """Final-pass shader: overlays animated stripes on the off-screen map texture.

    Per fragment, the map's luma (mapVal) shifts the stripe position
    (map2Params.x) and lowers the stripe threshold, widening the stripes
    (map2Params.y). Each stripe's color comes from one texel of the color
    texture (indexed by stripe number) and is blended over the map, which is
    first mixed with bgColor by bgColor.a. The output alpha is mapVal.
    """
    vs = """
        attribute highp vec2 posAttr;     // input attribute: (0,0)..(1,1)
        uniform mediump vec3 texTrans;    // texture coordinate transform
        uniform highp mat3 stripeMatrix;  // transform matrix for the stripes
        varying mediump vec2 tc;          // texture coordinate
        varying highp vec2 stripeCoord;   // coordinate in stripe coord. system
        void main() {
            gl_Position = vec4(posAttr * vec2(2.0, -2.0) + vec2(-1.0, 1.0), 0.0, 1.0);
            tc = (posAttr - vec2(0.5, 0.5) - texTrans.xy) / texTrans.z + vec2(0.5, 0.5);
            stripeCoord = (stripeMatrix * vec3(posAttr - vec2(0.5), 1.0)).xy;
        }
    """
    fs = """
        varying mediump vec2 tc;
        varying highp vec2 stripeCoord;
        uniform lowp sampler2D mapTex;      // map control texture
        uniform lowp sampler2D stripeTex;   // stripe texture
        uniform lowp sampler2D colorTex;    // color texture
        uniform mediump float ctScale;      // color texture coordinate scaling
        uniform mediump vec2 stripeScale;   // stripe intensity scaling (x = offset, y = scale)
        uniform mediump vec2 map2Params;    // influence of the map to {x = offset, y = width}
        uniform mediump vec3 wobble;        // wobble parameters (x = frequency, y = amplitude, z = phase)
        uniform mediump vec4 bgColor;       // background color; (mixed with map texture)
        void main() {
            lowp vec3 mapColor = texture2D(mapTex, tc).rgb;
            mediump float mapVal = dot(vec3(0.299, 0.587, 0.114), mapColor);
            mediump float stripePos = stripeCoord.y + map2Params.x * mapVal + wobble.y * sin(wobble.x * stripeCoord.x + wobble.z);
            mediump float stripeVal = texture2D(stripeTex, vec2(stripePos, 0.0)).r;
            stripeVal = clamp((stripeVal - clamp(stripeScale.x - map2Params.y * mapVal, 0.0, 1.0)) * stripeScale.y, 0.0, 1.0);
            mediump float stripeID = floor(stripePos + 0.5);
            lowp vec4 fgColor = texture2D(colorTex, vec2(0.5, (stripeID + 0.5) * ctScale));
            gl_FragColor = vec4(mix(mix(mapColor, bgColor.rgb, bgColor.a), fgColor.rgb, fgColor.a * stripeVal), mapVal);
            //gl_FragColor = vec4(mix(mix(mapColor.xyz, bgColor.rgb, bgColor.a), mix(mapColor.xyz, fgColor.rgb, fgColor.a), stripeVal), mapVal);
        }
    """
    attributes = { 'posAttr': 0 }
    # ('name', unit) entries presumably also assign the sampler's texture unit.
    # mapTex is not listed, so it stays on unit 0 (confirm in ezgl.Shader).
    uniforms = [ ('stripeTex', 1), ('colorTex', 2), 'texTrans', 'stripeMatrix', 'ctScale', 'stripeScale', 'map2Params', 'wobble', 'bgColor']

    def __init__(self):
        global VirtualSize
        ezgl.Shader.__init__(self)
        # Reference resolution (720 for 1280x720): converts the `scale`
        # argument into stripe periods and bounds the edge sharpness.
        self.screen_res = min(VirtualSize)
        # 64-texel triangle wave for the stripe profile: 248 down to 8, then
        # 0 up to 248, then 255 (sampled with REPEAT wrap and LINEAR filtering).
        data = bytes(list(range(256-8, 0, -8)) + list(range(0, 256, 8)) + [255])
        # 3x3 stripe matrix, passed to UniformMatrix3fv with transpose=False
        # (column-major). Entries 0,1,3,4,6,7 are rewritten by every draw();
        # [8] stays 1.0 and [2], [5] stay 0.
        self.mat = (ctypes.c_float * 9)()
        self.mat[8] = 1.0
        self.tex = gl.make_texture(wrap=gl.REPEAT, filter=gl.LINEAR)
        gl.TexImage2D(gl.TEXTURE_2D, 0, gl.LUMINANCE, 64, 1, 0, gl.LUMINANCE, gl.UNSIGNED_BYTE, data)

    def draw(self,
             map_tex,
             color_tex,
             color_tex_size = 1,
             scale=10.0,
             angle=0.0,
             offset=0.0,
             offset2=0.0,
             width=0.5,
             smooth=0.0,
             map2offset=0.0,
             map2width=0.0,
             wobble_freq=0.0,
             wobble_amp=0.0,
             wobble_phase=0.0,
             background=(0.0, 0.0, 0.0, 1.0),
             tex_dx=0.0,
             tex_dy=0.0,
             tex_zoom=0.0
        ):
        """Render the stripe pass as one quad covering the current viewport.

        map_tex: map texture (sampled on unit 0).
        color_tex: per-stripe color texture, color_tex_size texels tall.
        scale: at angle 0, min(VirtualSize) / scale stripes span the viewport
            height. For the default 1280x720 virtual size, that makes `scale`
            the stripe period in virtual pixels. Clamped to >= 1/128.
        angle: stripe rotation in degrees.
        offset: shift across the stripes. offset2: shift along them (this
            only affects the wobble phase).
        width: stripe threshold is 1 - width. smooth: edge softness (see the
            stripeScale comment below).
        map2offset / map2width: how much the map luma shifts and widens the stripes.
        wobble_*: sinusoidal displacement wAmp * sin(wFreq * x + wPhase)
            along the stripe direction.
        background: RGBA; rgb is mixed over the map with weight a.
        tex_dx / tex_dy / tex_zoom: map translation and zoom,
            tc = (p - 0.5 - d) / zoom + 0.5. Zoom 1.0 with d 0 is the identity.
            NOTE(review): the vertex shader divides by tex_zoom, so the
            default tex_zoom=0.0 divides by zero. Callers must pass a
            nonzero zoom.
        """
        TexturedRect.get_instance().use_vbo()
        self.use()
        # Bind color (unit 2) and stripe (unit 1) textures before the map
        # (unit 0). NOTE(review): presumably this order leaves unit 0 active
        # for the plain BindTexture calls in the main loop; confirm in
        # ezgl.set_texture.
        gl.set_texture(tex=color_tex, tmu=2)
        gl.set_texture(tex=self.tex, tmu=1)
        gl.set_texture(tex=map_tex, tmu=0)
        # From here on `scale` is reused as the number of stripe periods over
        # the viewport height (posAttr.y - 0.5 spans -0.5..0.5).
        scale = self.screen_res / max(scale, 1.0 / 128)
        angle = math.radians(angle)
        c = scale * math.cos(angle)
        s = scale * math.sin(angle)
        # stripeCoord.y = mat[1]*p.x + mat[4]*p.y + mat[7] (across the stripes).
        # The p.x terms are multiplied by VirtualAspect because p.x also
        # spans only 1.0 while covering the wider horizontal extent.
        self.mat[1] = s * VirtualAspect
        self.mat[4] = c
        self.mat[7] = offset
        # stripeCoord.x = mat[0]*p.x + mat[3]*p.y + mat[6] (along the stripes).
        self.mat[0] = c * VirtualAspect
        self.mat[3] = -s
        self.mat[6] = offset2
        gl.Uniform1f(self.ctScale, 1.0 / color_tex_size)
        gl.UniformMatrix3fv(self.stripeMatrix, 1, False, self.mat)
        # stripeScale = (threshold, edge gain). The gain is 1/smooth, capped at
        # half the on-screen pixels per stripe period. This uses the rescaled
        # `scale` above and the ScreenScalingFactor global set by
        # update_viewport, so edges stay roughly a pixel wide at minimum.
        gl.Uniform2f(self.stripeScale, 1.0 - width, min(1.0 / max(smooth, 0.0001), self.screen_res * ScreenScalingFactor * 0.5 / max(scale, 0.0001)))
        gl.Uniform2f(self.map2Params, map2offset, map2width)
        gl.Uniform3f(self.wobble, wobble_freq, wobble_amp, wobble_phase)
        gl.Uniform3f(self.texTrans, tex_dx, tex_dy, tex_zoom)
        gl.Uniform4f(self.bgColor, *background)
        gl.DrawArrays(gl.TRIANGLE_STRIP, 0, 4)

################################################################################
# MARK: main

if __name__ == "__main__":
    # True when running from a frozen/bundled executable (freezing tools set
    # sys.frozen). It makes fullscreen the default, lets Escape quit, and
    # passes client=False to rocket.Device below.
    frozen = bool(getattr(sys, 'frozen', None))

    parser = optparse.OptionParser(usage="%prog [OPTIONS...] <music.mp3|musicdir>")
    def parse_geometry(option, opt_str, value, parser):
        """optparse callback for -g/--geometry: store 'WxH' as an (int, int) tuple.

        Raises optparse.OptionValueError unless the value is exactly two
        integers separated by 'x' (case-insensitive).
        """
        try:
            x, y = map(int, value.lower().split('x'))
        except ValueError:
            raise optparse.OptionValueError("invalid --geometry argument '%s'" % value)
        parser.values.geometry = (x, y)
    parser.add_option('-v', '--verbose', action='count',
                      help="more verbose operation")
    parser.add_option('-f', '--fullscreen', action='store_true', default=frozen,
                      help="run in fullscreen mode")
    parser.add_option('-g', '--geometry', metavar='WxH', type='str', action='callback', callback=parse_geometry,
                      help="initial window size or fullscreen resolution")
    parser.add_option('-b', '--bpm', type='int', default=DefaultBPM,
                      help="beats per minute [default: %default]")
    parser.add_option('-r', '--rpb', type='int', default=DefaultRBP,
                      help="rows per beat [default: %default]")
    parser.add_option('-s', '--sync', metavar='NAME', default="sync",
                      help="base name of the synchronization track file [default: '%default']")
    opts, args = parser.parse_args()
    Verbosity = int(opts.verbose or 0)
    # An explicit --geometry wins. Otherwise windowed mode uses the virtual
    # resolution and fullscreen passes None to ezgl.Init (presumably "native
    # resolution"; confirm in ezgl).
    geometry = opts.geometry or (None if opts.fullscreen else VirtualSize)

    ezgl.Init(geometry, fullscreen=opts.fullscreen, resizable=True,
              title="no stars, just stripes")
    if opts.fullscreen:
        pygame.mouse.set_visible(False)
    ezgl.Shader.LOG_DEFAULT = ezgl.Shader.LOG_IF_NOT_EMPTY
    def update_viewport(w, h):
        """Fit the virtual resolution into a w x h window, preserving its aspect ratio.

        Sets the module globals ScreenScalingFactor (window pixels per virtual
        pixel) and MainViewport, the (x, y, width, height) of the centered
        area passed to gl.Viewport for the final pass.
        """
        global MainViewport, ScreenScalingFactor, VirtualSize
        ScreenScalingFactor = min(w / VirtualSize[0], h / VirtualSize[1])
        mw = int(VirtualSize[0] * ScreenScalingFactor + 0.5)
        mh = int(VirtualSize[1] * ScreenScalingFactor + 0.5)
        MainViewport = ((w - mw) // 2, (h - mh) // 2, mw, mh)
    # Initialize from the viewport GL reports for the freshly created window
    # (elements 2 and 3 are its width and height).
    vp = (ctypes.c_int32 * 4)()
    gl.GetIntegerv(gl.VIEWPORT, ctypes.cast(vp, ctypes.POINTER(ctypes.c_int32)))
    update_viewport(vp[2], vp[3])
    del vp

    # Music: the first command-line argument, else DefaultMusic. A directory
    # argument selects a random *.mp3 inside it.
    if not(args) and DefaultMusic:
        args = [DefaultMusic]
    try:
        filename = args[0]
        bassmusic.Init()
        if os.path.isdir(filename):
            import glob
            filename = random.choice(glob.glob(os.path.join(filename, "*.mp3")))
        music = bassmusic.Track(filename, bpm=opts.bpm, rpb=opts.rpb)
        print("music:", filename)
    except (IndexError, IOError, bassmusic.BASSError):
        # Triggered by no argument, a directory without *.mp3 (random.choice
        # raises IndexError), an I/O error, or a BASS failure. Fall back to
        # rocket.FakeController at bpm/60*rpb rows per second (presumably a
        # silent clock).
        music = rocket.FakeController(opts.bpm / 60.0 * opts.rpb)

    # Sync tracks. "group:name" entries are read below as
    # sync.group.name.value (e.g. sync.fg0.r.value).
    sync = rocket.Device(opts.sync, [
        "mapMode",
        "mapTrans:dx", "mapTrans:dy", "mapTrans:zoom",
        "stripes:scale", "stripes:angle",
        "stripes:width", "stripes:m2Width",
        "stripes:offset", "stripes:m2Offset",
        "stripes:smooth",
        "stripes:wFreq", "stripes:wAmp", "stripes:wPhase",
        "colorMode", "colorMask", "colorMix",
        "fg0:r", "fg0:g", "fg0:b", "fg0:a",
        "fg1:r", "fg1:g", "fg1:b", "fg1:a",
         "bg:r",  "bg:g",  "bg:b",  "bg:a",
    ], controller=music, client=not(frozen))

    def gen_mode_map(class_prefix, extra_callback=None):
        """Build {mode_id: instance} for all module classes named '<class_prefix>_*'.

        Classes whose mode_id is None are skipped. Pairs yielded by
        extra_callback() are added afterwards and override earlier entries.
        With -v, a table is printed using each object's _name, or else its
        class docstring (so mode class docstrings are user-visible).
        NOTE(review): iterates the live globals() dict; a constructor that
        created a new module global would raise "dictionary changed size
        during iteration".
        """
        class_prefix += '_'
        res = {}
        for name in globals():
            if name.startswith(class_prefix):
                c = globals()[name]
                if c.mode_id is None: continue
                res[c.mode_id] = c()
        if extra_callback:
            for mode_id, obj in extra_callback():
                res[mode_id] = obj
        if Verbosity:
            print(class_prefix[:-1] + "s:")
            for i in sorted(res):
                c = res[i].__class__
                n = getattr(res[i], '_name', c.__doc__)
                print("%6d = %s%s" % (i, c.__name__.split('_', 1)[-1], (" (%s)" % (n.strip().split('\n', 1)[0].strip()) if n else "")))
            print()
        return res

    def load_image_assets():
        """Yield (mode_id, MapMode_ImageFile) for each .jpg/.png in AssetDir whose name starts with digits.

        The leading number becomes the mode id. os.listdir order is
        arbitrary, so it is undefined which file wins for duplicate ids.
        """
        for fn in os.listdir(AssetDir):
            m = re.match(r'\d+', fn)
            ext = os.path.splitext(fn)[-1].strip(".").lower()
            if m and (ext in ("jpg", "png")):
                yield int(m.group(0), 10), MapMode_ImageFile(os.path.join(AssetDir, fn))

    map_modes = gen_mode_map("MapMode", load_image_assets)
    color_modes = gen_mode_map("ColorMode")

    # Off-screen render target at the virtual resolution. The active map mode
    # draws into `rtt`, which the stripe shader then samples as its map.
    fbo = gl.GenFramebuffers()
    gl.BindFramebuffer(gl.FRAMEBUFFER, fbo)
    zbuf = gl.GenRenderbuffers()
    gl.BindRenderbuffer(gl.RENDERBUFFER, zbuf)
    gl.RenderbufferStorage(gl.RENDERBUFFER, gl.DEPTH_COMPONENT16, VirtualSize[0], VirtualSize[1])
    gl.FramebufferRenderbuffer(gl.FRAMEBUFFER, gl.DEPTH_ATTACHMENT, gl.RENDERBUFFER, zbuf)
    rtt = gl.make_texture(filter=gl.LINEAR)
    gl.TexImage2D(gl.TEXTURE_2D, 0, gl.RGBA, VirtualSize[0], VirtualSize[1], 0, gl.RGBA, gl.UNSIGNED_BYTE, None)
    gl.FramebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, rtt, 0)
    # Explicit check instead of `assert`: under `python -O` an assert, and the
    # CheckFramebufferStatus call inside it, would be skipped entirely.
    fbo_status = gl.CheckFramebufferStatus(gl.FRAMEBUFFER)
    if fbo_status != gl.FRAMEBUFFER_COMPLETE:
        raise RuntimeError("off-screen framebuffer is incomplete (status %r)" % (fbo_status,))
    gl.BindFramebuffer(gl.FRAMEBUFFER, 0)

    # Stripe shader, plus the 1-texel-wide color texture it samples. The
    # color texture is re-uploaded every frame from the active color mode.
    ssh = StripeShader()
    colortex = gl.make_texture(filter=gl.NEAREST, wrap=gl.REPEAT)

    gl.ClearColor(0.0, 0.0, 0.0, 1.0)

    sync.start()

    while True:
        # Drain the pygame event queue; poll() returns a falsy event once it is empty.
        while True:
            ev = pygame.event.poll()
            if not ev:
                break
            if ev.type == QUIT:
                sys.exit(0)
            elif ev.type == KEYDOWN:
                # 'q' always quits; Escape only in frozen builds.
                if (ev.key == K_q) or (frozen and (ev.key == K_ESCAPE)):
                    pygame.event.post(pygame.event.Event(QUIT))
            elif ev.type == VIDEORESIZE:
                ezgl.NotifyResize(ev.w, ev.h)
                update_viewport(ev.w, ev.h)

        sync.update()
        # In player mode (presumably: no editor connected), stop at QuitAtRow.
        if sync.is_player() and (sync.row >= QuitAtRow):
            break
        # Per-frame color tuples. The fg colors stay on a 0..255 scale for
        # ColorMode.tobytes, with alpha multiplied by 255 (tracks presumably
        # store alpha as 0..1). The background is converted to 0..1 for the
        # bgColor uniform; its alpha is passed through unchanged.
        sync.fg0.v = (sync.fg0.r.value, sync.fg0.g.value, sync.fg0.b.value, 255.0 * sync.fg0.a.value)
        sync.fg1.v = (sync.fg1.r.value, sync.fg1.g.value, sync.fg1.b.value, 255.0 * sync.fg1.a.value)
        sync.bg.v  = (sync.bg.r.value / 255.0, sync.bg.g.value / 255.0, sync.bg.b.value / 255.0, sync.bg.a.value)

        # Pass 1: draw the selected map mode into the off-screen target at the
        # virtual resolution. The track value is rounded (int(0.5 + v), which
        # rounds to nearest for non-negative values); unknown ids fall back to mode 0.
        gl.BindFramebuffer(gl.FRAMEBUFFER, fbo)
        gl.Viewport(0, 0, VirtualSize[0], VirtualSize[1])
        gl.Clear(gl.COLOR_BUFFER_BIT)
        map_modes.get(int(0.5 + sync.mapMode.value), map_modes[0]).draw()
        # Pass 2: back to the window, restricted to the aspect-correct viewport.
        gl.BindFramebuffer(gl.FRAMEBUFFER, 0)
        gl.Viewport(*MainViewport)

        gl.Clear(gl.COLOR_BUFFER_BIT)

        # Upload the color mode's RGBA bytes as a 1 x size texture (one texel per stripe color).
        gl.BindTexture(gl.TEXTURE_2D, colortex)
        data = color_modes.get(int(0.5 + sync.colorMode.value), color_modes[0]).get()
        size = len(data) // 4
        gl.TexImage2D(gl.TEXTURE_2D, 0, gl.RGBA, 1, size, 0, gl.RGBA, gl.UNSIGNED_BYTE, data)
        # print("colorTex:", " ".join(f"{x:02X}" for x in data))

        ssh.draw(
            map_tex = rtt,
            color_tex = colortex,
            color_tex_size = size,
            scale        = sync.stripes.scale.value,
            angle        = sync.stripes.angle.value,
            width        = sync.stripes.width.value,
            smooth       = sync.stripes.smooth.value,
            offset       = sync.stripes.offset.value,
            map2offset   = sync.stripes.m2Offset.value,
            map2width    = sync.stripes.m2Width.value,
            wobble_freq  = sync.stripes.wFreq.value,
            wobble_amp   = sync.stripes.wAmp.value,
            wobble_phase = sync.stripes.wPhase.value,
            background   = sync.bg.v,
            tex_dx       = sync.mapTrans.dx.value,
            tex_dy       = sync.mapTrans.dy.value,
            tex_zoom     = sync.mapTrans.zoom.value,
        )

        ezgl.SwapBuffers()
