﻿#include "pch.h"
#include "Model.h"

#include "engine/Engine.h"
#include "gl/Buffer.h"

#include <atomic>
#include <print>

#include "engine/model/Mesh.h"

#include "FileWatch.hpp"


#define max(a,b) (a > b ? (a) : (b))


// Recursively mirrors the assimp hierarchy below `parent_node` into this->nodes and
// attaches physics metadata to `curr_model_node`.
//
// Pass 1 - every child whose name does NOT start with "_Col" becomes a ModelNode. Each one:
//   - is appended to this->nodes;
//   - gets parent_node->mTransformation * child_node->mTransformation, decomposed into
//     position/scale/rotation;
//   - is linked into curr_model_node->children and children_map;
//   - is recursed into.
// Pass 2 - marker children configure curr_model_node:
//   "_UnlockRot"         -> rot_locked = (false, false, false)
//   "_Dynamic"           -> is_dynamic_physics_object = true
//   "_Col..."            -> convex hull of the marker's first mesh, scaled by the marker's local scale
//   "_MeshCol..."        -> convex hull of parent_node's first mesh, scaled by parent_node's local scale
//   "_NonconvMeshCol..." -> triangle-mesh shape of parent_node's first mesh (no scale applied)
//   "_NoPhysics"         -> not handled yet
//
// Returns curr_model_node->children by value.
//
// NOTE: curr_model_node and every child pointer point into this->nodes. They stay valid
// only while push_back does not reallocate; re_create() presumably reserves capacity
// up front for exactly that reason.
smallvec<ModelNode*, 100> Model::parse_node(aiNode* parent_node, ModelNode* curr_model_node) {
	// Converts an indexed mesh into a Jolt triangle list. Positions are packed as
	// (x, y, z, w) quadruples with w == 0, which is how re_create() fills Mesh::posList.
	auto create_col_triangles = [](const std::vector<int>& indices, const std::vector<float>& positions) -> JPH::TriangleList {
		auto vertex_at = [&positions](int idx) -> JPH::Float3 {
			const size_t base = static_cast<size_t>(idx) * 4;
			assert(positions[base + 3] == 0);
			return JPH::Float3(positions[base], positions[base + 1], positions[base + 2]);
		};

		JPH::TriangleList triangles;
		for (size_t k = 0; k + 2 < indices.size(); k += 3) {
			triangles.push_back(JPH::Triangle(vertex_at(indices[k]), vertex_at(indices[k + 1]), vertex_at(indices[k + 2])));
		}
		return triangles;
	};

	const unsigned child_node_cnt = parent_node->mNumChildren;

	// ------------------------ CREATE CHILDREN ------------------------ //
	for (unsigned i = 0; i < child_node_cnt; i++) {
		aiNode* child_node = parent_node->mChildren[i];
		if (std::string(child_node->mName.C_Str()).starts_with("_Col")) {
			continue; // "_Col" markers only describe a collider; they are consumed in the second pass
		}

		this->nodes.push_back(ModelNode(child_node, this));
		ModelNode* child_model_node = &this->nodes[this->nodes.size() - 1];

		child_model_node->ai_transform_matrix = parent_node->mTransformation * child_node->mTransformation;

		aiVector3D position;
		aiVector3D scaling;
		aiQuaternion rotation;
		child_model_node->ai_transform_matrix.Decompose(scaling, rotation, position);

		child_model_node->position = glm::vec3(position.x, position.y, position.z);
		child_model_node->scale = glm::vec3(scaling.x, scaling.y, scaling.z);
		child_model_node->rotation = glm::quat(rotation.w, rotation.x, rotation.y, rotation.z);

		// NOTE(review): aiMatrix4x4 is row-major while glm::mat4 is column-major, so this raw
		// copy stores the transpose of ai_transform_matrix. Confirm that consumers expect that.
		memcpy(&child_model_node->transform_matrix[0][0], &child_model_node->ai_transform_matrix[0][0], sizeof(float) * 16);

		curr_model_node->children.push_back(child_model_node);

		this->parse_node(child_node, child_model_node);

		curr_model_node->children_map[child_model_node->name] = child_model_node;
	}

	// ------------------------ PARSE CHILDREN FOR CURR NODE ------------------------ //
	for (unsigned i = 0; i < child_node_cnt; i++) {
		aiNode* child_node = parent_node->mChildren[i];
		const std::string name = child_node->mName.C_Str();

		if (name.starts_with("_UnlockRot")) {
			curr_model_node->rot_locked = glm::bvec3(false, false, false);
		}
		// TODO: "_NoPhysics" markers are recognised by name elsewhere but not handled yet.
		if (name.starts_with("_Dynamic")) {
			curr_model_node->is_dynamic_physics_object = true;
		}

		const bool is_col = name.starts_with("_Col");
		const bool is_mesh_col = name.starts_with("_MeshCol");
		const bool is_nonconv_mesh_col = name.starts_with("_NonconvMeshCol");
		if (!is_col && !is_mesh_col && !is_nonconv_mesh_col) {
			continue;
		}

		// "_Col" markers carry their own collider mesh; the other variants reuse the parent's mesh.
		// TODO: get from parent_model_node
		aiNode* shape_node = is_col ? child_node : parent_node;
		if (shape_node->mNumMeshes == 0) {
			// mMeshes[0] would read past an empty array.
			std::cout << "Mesh Shape Creation Error: node '" << shape_node->mName.C_Str() << "' has no mesh \n";
			continue;
		}
		Mesh* mesh = this->meshes[shape_node->mMeshes[0]];

		aiVector3D position;
		aiVector3D scaling;
		aiQuaternion rotation;
		shape_node->mTransformation.Decompose(scaling, rotation, position);
		const glm::vec3 scale = glm::vec3(scaling.x, scaling.y, scaling.z);

		JPH::Result<JPH::Ref<JPH::Shape>> mesh_shape_creation_res;
		if (is_nonconv_mesh_col) {
			JPH::TriangleList triangles = create_col_triangles(mesh->indicesList, mesh->posList);
			JPH::Ref<JPH::MeshShapeSettings> shape_settings = new JPH::MeshShapeSettings(triangles);
			mesh_shape_creation_res = shape_settings->Create();
		} else {
			JPH::Array<JPH::Vec3> vert_arr;
			for (size_t v = 0; v < mesh->posList.size(); v += 4) {
				vert_arr.push_back(JPH::Vec3(mesh->posList[v], mesh->posList[v + 1], mesh->posList[v + 2]));
			}
			const float margin = 0.04f;
			JPH::Ref<JPH::ConvexHullShapeSettings> shape_settings = new JPH::ConvexHullShapeSettings(vert_arr, margin);
			mesh_shape_creation_res = shape_settings->Create();
			// Result::Get() is only valid on success, so scale the hull only if it was built.
			if (!mesh_shape_creation_res.HasError()) {
				mesh_shape_creation_res = mesh_shape_creation_res.Get()->ScaleShape(glm::to_jph(scale));
			}
		}

		if (mesh_shape_creation_res.HasError()) {
			std::cout << "Mesh Shape Creation Error: " << mesh_shape_creation_res.GetError().c_str() << " \n";
		} else {
			JPH::Ref<JPH::Shape> s = mesh_shape_creation_res.Get();
			curr_model_node->collision_shape = std::move(s);
		}
	}

	return curr_model_node->children;
}

// Builds the absolute model path as "<current dir>\src\<_path>" and loads the model via
// re_create(). It then starts a file watcher on that file, which sets
// WEngine->scene_reloaded when a .glb is modified or renamed. Events closer than
// 0.15 s to the previous one are ignored.
//
// NOTE: the watcher is a function-local static, so only the first Model constructed
// creates one (watching that Model's path). It is never destroyed and lives for the
// whole process.
Model::Model(std::string _path):path(_path) {
	const std::string current_dir = std::filesystem::current_path().string();
	path = current_dir + std::string("\\") + std::string("src\\") + path;
	this->re_create();

	// `path` names a single file, so the watcher observes that file.
	const std::wstring watched_path = std::filesystem::absolute(std::filesystem::path(path)).wstring();
	const std::wregex glb_filter(L".*\\.(glb)$");

	// clock() value of the last watcher event, used for debouncing. It deliberately does
	// not live in a Model: the watcher outlives any individual Model, and capturing `this`
	// would leave the callback writing through a dangling pointer once the first Model
	// is destroyed.
	static std::atomic<long> s_last_change_clock{0};

	[[maybe_unused]] static filewatch::FileWatch<std::wstring>* file_watch = new filewatch::FileWatch<std::wstring>(
		watched_path,
		[glb_filter](const std::wstring& changed_path, const filewatch::Event change_type) {
			const long curr_time = clock();
			const long prev_change_time = s_last_change_clock.exchange(curr_time);
			const double delta_time_secs = double(curr_time - prev_change_time) / double(CLOCKS_PER_SEC);

			// Presumably a single save fires a burst of events; only react once per burst.
			if (delta_time_secs <= 0.15) {
				return;
			}
			if (!std::regex_search(changed_path, glb_filter)) {
				return;
			}
			const bool changed_file = change_type == filewatch::Event::modified ||
				change_type == filewatch::Event::renamed_old;
			if (changed_file) {
				WEngine->scene_reloaded.store(true);
			}
		});
}

#include <psapi.h>

// Loads the model at `path` with assimp and builds this Model's data from it:
//   * lights   - copied from scene->mLights. A light's position is taken from the root
//                child node with the same name, if there is one.
//   * textures - one Texture per embedded scene texture, so this->textures[i] matches
//                scene->mTextures[i].
//   * meshes   - one Mesh per aiMesh. Positions and normals (vec4, w = 0), UVs (vec2,
//                channel 0) and indices are also concatenated into four shared GPU
//                buffers. Each Mesh receives the running element counts of those buffers
//                at the point where its data starts.
//   * nodes    - the ModelNode tree, built by parse_node() from a newly appended root.
//
// NOTE(review): ReadFile is retried until the scene passes scene_looks_valid. This is
// presumably meant to survive reading a file that is still being written during hot
// reload; a file that never becomes loadable makes this loop forever.
// NOTE(review): the scene is freed before returning, but this->ai_scene and
// this->scene_non_triangulated (which also holds the triangulated scene) still point at
// it. They must not be dereferenced after re_create() returns.
// NOTE(review): textures/meshes/nodes/lights are appended to and never cleared, so a
// second call on the same Model would mix old and new data.
void Model::re_create() {
	ZoneScopedN("reload model");

	Assimp::Importer importer;

	// parse_node() keeps raw pointers into this->nodes, which therefore must not reallocate
	// while the hierarchy is built.
	this->textures.reserve(1000);
	this->nodes.reserve(1000);
	this->meshes.reserve(5000);
	this->mesh_outlines.reserve(5000);

	// Accepts a scene once it is readable, has a readable root node and a readable mesh array.
	// IsBadReadPtr is only a best-effort probe (Microsoft documents it as obsolete).
	auto scene_looks_valid = [](const aiScene* candidate) -> bool {
		if (IsBadReadPtr(candidate, sizeof(aiScene))) {
			return false;
		}
		// NOTE(review): scenes that report cameras are rejected as well. Intent unclear; confirm.
		if (candidate->mNumCameras != 0) {
			return false;
		}
		if (candidate->mRootNode == nullptr || IsBadReadPtr(candidate->mRootNode, sizeof(aiNode))) {
			return false;
		}
		// mMeshes is an array of aiMesh pointers. Probe one pointer rather than a whole
		// aiMesh, so a short array near the end of a page is not mistaken for a bad scene.
		return !IsBadReadPtr(candidate->mMeshes, sizeof(aiMesh*));
	};

	const aiScene* scene = nullptr;
	do {
		scene = importer.ReadFile(path, aiProcess_Triangulate | aiProcess_OptimizeMeshes);
		wlog_info("-- LOADED SCENE -- ");
		if (scene == nullptr) {
			std::println("ERROR LOADING SCENE '{}': {}", path, importer.GetErrorString());
		}
	} while (!scene_looks_valid(scene));

	this->scene_non_triangulated = scene;
	this->ai_scene = scene;

	// ------------------------ LIGHTS ------------------------ //
	for (unsigned light_idx = 0; light_idx < scene->mNumLights; light_idx++) {
		aiLight* light = scene->mLights[light_idx];
		for (unsigned child_idx = 0; child_idx < scene->mRootNode->mNumChildren; child_idx++) {
			const aiNode* node = scene->mRootNode->mChildren[child_idx];
			if (strcmp(light->mName.C_Str(), node->mName.C_Str()) == 0) {
				aiVector3D position;
				aiVector3D scaling;
				aiQuaternion rotation;
				node->mTransformation.Decompose(scaling, rotation, position);

				light->mPosition = position;
			}
		}

		this->lights.push_back(*light);
	}

	// ------------------------ TEXTURES ------------------------ //
	// Exactly one Texture is pushed per scene texture, even when decoding fails, so that
	// this->textures stays index-aligned with scene->mTextures.
	for (unsigned tex_idx = 0; tex_idx < scene->mNumTextures; tex_idx++) {
		const aiTexture* ai_tex = scene->mTextures[tex_idx];

		int width = 0;
		int height = 0;
		int components_per_pixel = 0;
		unsigned char* stbi_pixels = nullptr;    // allocated by stb_image, freed after upload
		std::vector<unsigned char> local_pixels; // raw-texel conversion or fallback pixel
		unsigned char* pixels = nullptr;

		stbi_set_flip_vertically_on_load(false);
		if (ai_tex->mHeight == 0) {
			// Compressed payload (PNG/JPG/...) of mWidth bytes.
			stbi_pixels = stbi_load_from_memory(reinterpret_cast<const unsigned char*>(ai_tex->pcData), static_cast<int>(ai_tex->mWidth), &width, &height, &components_per_pixel, 0);
			pixels = stbi_pixels;
			if (pixels == nullptr) {
				const char* reason = stbi_failure_reason();
				std::println("Failed to decode embedded texture {}: {}", tex_idx, reason ? reason : "unknown");
			}
		} else {
			// Uncompressed: mWidth * mHeight aiTexels (b, g, r, a bytes). Reorder to RGBA.
			width = static_cast<int>(ai_tex->mWidth);
			height = static_cast<int>(ai_tex->mHeight);
			components_per_pixel = 4;
			const size_t texel_cnt = static_cast<size_t>(ai_tex->mWidth) * ai_tex->mHeight;
			local_pixels.resize(texel_cnt * 4);
			for (size_t t = 0; t < texel_cnt; t++) {
				const aiTexel& texel = ai_tex->pcData[t];
				local_pixels[t * 4 + 0] = texel.r;
				local_pixels[t * 4 + 1] = texel.g;
				local_pixels[t * 4 + 2] = texel.b;
				local_pixels[t * 4 + 3] = texel.a;
			}
			pixels = local_pixels.data();
		}

		if (pixels == nullptr) {
			// Decoding failed: use a 1x1 opaque magenta placeholder instead of uploading
			// from a null pointer with uninitialised dimensions.
			width = 1;
			height = 1;
			components_per_pixel = 4;
			local_pixels = {255, 0, 255, 255};
			pixels = local_pixels.data();
		}

		auto internal_format = InternalFormat::RGBA8;
		if (components_per_pixel == 3) {
			internal_format = InternalFormat::SRGB8;
		} else if (components_per_pixel == 4) {
			internal_format = InternalFormat::SRGB8_ALPHA8;
		}

		Texture* tex = WEngine->alloc_textures.push(TextureDesc{
			(uint32_t)width, (uint32_t)height, 1,
			internal_format
		}).ptr;
		tex->is_model_texture = true;
		this->textures.push_back(tex);

		tex->upload_data(pixels);
		if (stbi_pixels != nullptr) {
			stbi_image_free(stbi_pixels);
		}
	}

	// ------------------------ MESHES ------------------------ //
	int buff_tex_coords_counter = 0;
	int buff_normals_counter = 0;
	int buff_vert_indices_counter = 0;
	int buff_verts_counter = 0;

	std::vector<float> globalPosList;
	std::vector<float> globalTexCoordList;
	std::vector<float> globalNormalsList;
	std::vector<float> globalColoursList;
	std::vector<int> globalIndicesList;

	const unsigned scene_mesh_cnt = scene->mNumMeshes;
	for (unsigned mesh_idx = 0; mesh_idx < scene_mesh_cnt; mesh_idx++) {
		aiMesh* ai_mesh = scene->mMeshes[mesh_idx];
		const unsigned num_verts = ai_mesh->mNumVertices;
		const unsigned face_count = ai_mesh->mNumFaces;

		std::vector<float> posList;
		std::vector<float> texCoordList;
		std::vector<float> normalsList;
		std::vector<float> coloursList; // vertex colours are not imported yet
		std::vector<int> indicesList;

		posList.reserve(static_cast<size_t>(num_verts) * 4);
		for (unsigned v = 0; v < num_verts; v++) {
			const aiVector3D& pos = ai_mesh->mVertices[v];
			posList.push_back(pos.x);
			posList.push_back(pos.y);
			posList.push_back(pos.z);
			posList.push_back(0);
		}

		if (ai_mesh->HasNormals()) {
			normalsList.reserve(static_cast<size_t>(num_verts) * 4);
			for (unsigned v = 0; v < num_verts; v++) {
				const aiVector3D& normal = ai_mesh->mNormals[v];
				normalsList.push_back(normal.x);
				normalsList.push_back(normal.y);
				normalsList.push_back(normal.z);
				normalsList.push_back(0);
			}
		}

		// A vertex can carry up to 8 UV sets; only the first one is used.
		if (ai_mesh->mTextureCoords[0]) {
			texCoordList.reserve(static_cast<size_t>(num_verts) * 2);
			for (unsigned v = 0; v < num_verts; v++) {
				texCoordList.push_back(ai_mesh->mTextureCoords[0][v].x);
				texCoordList.push_back(ai_mesh->mTextureCoords[0][v].y);
			}
		}

		indicesList.reserve(static_cast<size_t>(face_count) * 3);
		for (unsigned f = 0; f < face_count; f++) {
			// Bind by reference: copying an aiFace deep-copies its index array.
			const aiFace& face = ai_mesh->mFaces[f];
			if (face.mNumIndices != 3) {
				std::println("Wrong index cnt");
			}
			for (unsigned j = 0; j < face.mNumIndices; j++) {
				indicesList.push_back(static_cast<int>(face.mIndices[j]));
			}
		}

		globalPosList.insert(globalPosList.end(), posList.begin(), posList.end());
		globalTexCoordList.insert(globalTexCoordList.end(), texCoordList.begin(), texCoordList.end());
		globalNormalsList.insert(globalNormalsList.end(), normalsList.begin(), normalsList.end());
		globalColoursList.insert(globalColoursList.end(), coloursList.begin(), coloursList.end());
		globalIndicesList.insert(globalIndicesList.end(), indicesList.begin(), indicesList.end());

		Mesh* mesh = new Mesh(
			ai_mesh,
			false,
			this,
			buff_tex_coords_counter,
			buff_normals_counter,
			buff_vert_indices_counter,
			buff_verts_counter
		);

		buff_tex_coords_counter += static_cast<int>(texCoordList.size());
		buff_normals_counter += static_cast<int>(normalsList.size());
		buff_vert_indices_counter += static_cast<int>(indicesList.size());
		buff_verts_counter += static_cast<int>(posList.size());

		// The per-mesh lists are no longer needed here, so hand them over instead of copying.
		mesh->posList = std::move(posList);
		mesh->texCoordList = std::move(texCoordList);
		mesh->normalsList = std::move(normalsList);
		mesh->coloursList = std::move(coloursList);
		mesh->indicesList = std::move(indicesList);

		meshes.push_back(mesh);
	}

	// .data() is well-defined for empty vectors (e.g. a model without UVs), unlike &v[0].
	// NOTE(review): an empty list yields byte_len 0 and possibly a null upload_data.
	// Confirm that Buffer accepts that.
	this->buff_tex_coords = new Buffer({.byte_len = (int)globalTexCoordList.size() * 4, .upload_data = globalTexCoordList.data(), .name = "Buff Texcoords"});
	this->buff_normals = new Buffer({.byte_len = (int)globalNormalsList.size() * 4, .upload_data = globalNormalsList.data(), .name = "Buff Normals"});
	this->buff_vert_indices = new Buffer({.byte_len = (int)globalIndicesList.size() * 4, .upload_data = globalIndicesList.data(), .name = "Buff Indices" });
	this->buff_verts = new Buffer({.byte_len = (int)globalPosList.size() * 4, .upload_data = globalPosList.data(), .name = "Buff Verts"});

	// Outline meshes (this->mesh_outlines) are currently not generated.

	// ------------------------ NODE TREE ------------------------ //
	this->nodes.push_back(ModelNode(scene->mRootNode, this));
	parent_node = &this->nodes[0];
	ModelNode* root_node = &this->nodes[this->nodes.size() - 1];
	root_node->ai_transform_matrix = scene->mRootNode->mTransformation;
	this->parse_node(scene->mRootNode, root_node);

	importer.FreeScene();
}
