/*---------------------------------------------------------
    This file is part of the program suite TEN
    (Tools for Elastic Networks)

    Copyright (C)

	Lars Ackermann
    G. Matthias Ullmann
    Bayreuth 2014

    www.bisb.uni-bayreuth.de

    This program is free software: you can redistribute
    it and/or modify it under the terms of the
    GNU Affero General Public License as published by the
    Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will
    be useful, but WITHOUT ANY WARRANTY; without even the
    implied warranty of MERCHANTABILITY or FITNESS FOR A
    PARTICULAR PURPOSE. See the GNU Affero General Public
    License for more details.

    You should have received a copy of the
    GNU Affero General Public License along with this
    program.  If not, see <http://www.gnu.org/licenses/>.
-----------------------------------------------------------*/


#include "TrajectoryGeneratorLapack.h"
#include <iostream>
#include "gmd.h"
#include "laslv.h"
#include "blas3pp.h"
#include "ConfComparator.h"

/**
 * Forwards the settings and the file handler to AbstractTrajectoryGenerator.
 * No further initialisation happens in this class.
 */
TrajectoryGeneratorLapack::TrajectoryGeneratorLapack(SetupConstants* settings, FileHandler* fileHandler)
	: AbstractTrajectoryGenerator(settings, fileHandler) {
	// intentionally empty
}

/** Empty: the constructor of this class acquires no resources of its own. */
TrajectoryGeneratorLapack::~TrajectoryGeneratorLapack() {
	// nothing to release
}

void TrajectoryGeneratorLapack::calcAndPrintSingleTrajectories(const char* protName, vector<vector<double> > startCoords, vector<double> eigenvals, double* eigenvecs, int internalProtId) {

	vector<int> trajs = settings -> getTrajs();

	SettingsMap setMap = settings -> getSettings();
	vector<string> paths = settings -> getPaths();
	string path;
	if (internalProtId < paths.size()) {
		path = paths.at(internalProtId);
	} else {
		path = "";
	}
	BBMap bbMap = settings -> getBBNumbers();
	ConfComparator* confComp = NULL;
	const char* refName;
	if (setMap["compareConfs"] == 1) {
		confComp = new ConfComparator(fileHandler, settings);
		vector<string> refNames = settings -> getRefNames();
		vector<string> refPaths = settings -> getRefPaths();
		string path = "";
		if (internalProtId < refPaths.size()) {
			path = refPaths.at(internalProtId);
		}
		refName = refNames.at(internalProtId).c_str();
		vector<vector<double> > refCoords = fileHandler -> readPQRM(path.c_str(), refNames.at(internalProtId).c_str(), bbMap);

		settings -> setAllBBNumbers(bbMap);
		//The following makes sure that you compare a structure with the same number of back bone atoms -->
		if (bbMap[refName].size() != bbMap[protName].size()) {
			std::cout
					<< "WARNING: The protein's and the reference protein's structure hold a different number of back bone atoms. To compare conformations to a reference structure it is necessary that the structures hold the SAME number of bb atoms. So please revise the pqrms and re-compute the eigenvalues and -vectors. Program will not compare conformations and reference structure this time."
					<< std::endl;
			setMap["compareConfs"] = 0;
		} else {
			confComp = new ConfComparator(fileHandler, settings);
			confComp -> setRefCoords(refCoords, refName);
		}
	}
	vector<double> linExpansions = this -> calculateLinearExpansions(eigenvals);
	int trajsteps = ((2 * M_PI) / setMap["traj_tstep"]) + 1;
	int eigenVecSize = eigenvals.size() + 6;
	for (uint i = 0; i < trajs.size(); i++) {
		int trajId = trajs.at(i);
		double* oneEigVec = new double[eigenVecSize];
		for (int subId = 0; subId < eigenVecSize; subId++) {
			oneEigVec[subId] = eigenvecs[(trajId - 1) * (eigenVecSize) + subId];
		}
		FILE* file = NULL;
		if (setMap["print_traj"] == 1) {
			file = fileHandler -> createAndOpenDCDFile(trajId, protName, path.c_str());
			fileHandler -> writeDCDHeader(file, eigenVecSize / 3, trajsteps);
		} else if (setMap["print_traj"] == 2) {
			file = fileHandler -> createAndOpenPDBFile(trajId, protName, path.c_str());
		}
		FILE* distFile = NULL;
		if (setMap["compareConfs"] == 1) {
			distFile = fileHandler -> createAndOpenDistFile(trajId, protName, path.c_str());
		}

		int counter = 1;
		for (double t = 0.0; t < 2 * M_PI; t += setMap["traj_tstep"]) {
			double* displacement = this -> calculateSingleDisplacementStep(linExpansions, oneEigVec, trajId - 1, t);
			float** changedCoords = this -> calculateNewCoords(startCoords, displacement);
			if (setMap["print_traj"] == 1) {
				fileHandler -> writeDCD(file, changedCoords, eigenVecSize / 3);
			} else if (setMap["print_traj"] == 2) {
				fileHandler -> writePseudoPDB(file, changedCoords, eigenVecSize / 3);
			}
			if (setMap["compareConfs"] == 1 and t < M_PI) {
				confComp -> setTimeStepCoords(changedCoords);
				confComp -> calculateDiff(t, protName, refName);
			}
			//free memory
			for (int p = 0; p < 3; p++) {
				delete[] changedCoords[p];
			}
			delete[] changedCoords;
			delete[] displacement;
			counter++;
		}
		if (setMap["compareConfs"] == 1) {
			fileHandler -> writeDists(distFile, confComp -> getMinimalDiff());
			fileHandler -> closeFile(distFile);
		}
		if (setMap["print_traj"] == 1 or setMap["print_traj"] == 2) {
			fileHandler -> closeFile(file);
		}
		delete[] oneEigVec;
	}
	if (setMap["compareConfs"] == 1) {
		delete confComp;
	}
}

void TrajectoryGeneratorLapack::calcAndPrintSuperTrajectories(const char* protName, vector<vector<double> > startCoords, vector<double> eigenvals, double* eigenvecs, int internalProtId) {

	vector<int> supertrajs = settings -> getSuperTrajs();
	if (supertrajs.size() == 0) {
		for (uint i = 1; i <= eigenvals.size(); i++) {
			supertrajs.push_back(i);
		}
	}

	SettingsMap setMap = settings -> getSettings();
	vector<string> paths = settings -> getPaths();
	string path;
	if (internalProtId < paths.size()) {
		path = paths.at(internalProtId);
	} else {
		path = "";
	}
	PhaseMap phaseMap = settings -> getPhases();
	phaseMap = this -> preparePhasesMap(phaseMap, supertrajs);
	BBMap bbMap = settings -> getBBNumbers();

	ConfComparator* confComp = NULL;
	const char* refName;
	if (setMap["compareConfs"] == 1) {
		confComp = new ConfComparator(fileHandler, settings);
		vector<string> refNames = settings -> getRefNames();
		vector<string> refPaths = settings -> getRefPaths();
		string path = "";
		if (internalProtId < refPaths.size()) {
			path = refPaths.at(internalProtId);
		}
		refName = refNames.at(internalProtId).c_str();
		vector<vector<double> > refCoords = fileHandler -> readPQRM(path.c_str(), refNames.at(internalProtId).c_str(), bbMap);
		settings -> setAllBBNumbers(bbMap);
		//The following makes sure that you compare a structure with the same number of back bone atoms -->
		if (bbMap[refName].size() != bbMap[protName].size()) {
			std::cout
					<< "WARNING: The protein's and the reference protein's structure hold a different number of back bone atoms. To compare conformations to a reference structure it is necessary that the structures hold the SAME number of bb atoms. So please revise the pqrms and re-compute the eigenvalues and -vectors. Program will not compare conformations and reference structure this time."
					<< std::endl;
			setMap["compareConfs"] = 0;
		} else {
			confComp = new ConfComparator(fileHandler, settings);
			confComp -> setRefCoords(refCoords, refName);
		}
	}

	vector<double> linExpansions = this -> calculateLinearExpansions(eigenvals);
	int eigenVecSize = eigenvals.size() + 6;
	FILE* file = NULL;
	if (setMap["print_super"] == 1) {
		file = fileHandler -> createAndOpenDCDFile(0, protName, path.c_str());
		fileHandler -> writeDCDHeader(file, eigenVecSize / 3, setMap["supertraj_steps"] + 1);
	} else if (setMap["print_super"] == 2) {
		file = fileHandler -> createAndOpenPDBFile(0, protName, path.c_str());
	}
	FILE* distFile = NULL;
	if (setMap["compareConfs"] == 1) {
		distFile = fileHandler -> createAndOpenDistFile(0, protName, path.c_str());
	}

	for (double t = 0.0; t < setMap["supertraj_steps"] * setMap["supertraj_tstep"]; t += setMap["supertraj_tstep"]) {
		double* fullDisplacement = new double[eigenVecSize];
		for (int p = 0; p < eigenVecSize; p++) {
			fullDisplacement[p] = 0.0;
		}

		for (uint i = 0; i < supertrajs.size(); i++) {
			int superTrajId = supertrajs.at(i);
			double* oneEigVec = new double[eigenVecSize];
			for (int subId = 0; subId < eigenVecSize; subId++) {
				oneEigVec[subId] = eigenvecs[(superTrajId - 1) * (eigenVecSize) + subId];
			}
			double* displacement = this -> calculateSuperDisplacementStep(linExpansions, oneEigVec, eigenvals, superTrajId - 1, t, phaseMap);
			for (int j = 0; j < eigenVecSize; j++) {
				fullDisplacement[j] += displacement[j];
			}
			delete[] displacement;
			delete[] oneEigVec;
		}
		float** changedCoords = this -> calculateNewCoords(startCoords, fullDisplacement);
		if (setMap["print_super"] == 1) {
			fileHandler -> writeDCD(file, changedCoords, eigenVecSize / 3);
		} else if (setMap["print_super"] == 2) {
			fileHandler -> writePseudoPDB(file, changedCoords, eigenVecSize / 3);
		}
		if (setMap["compareConfs"] == 1) {
			confComp -> setTimeStepCoords(changedCoords);
			confComp -> calculateDiff(t, protName, refName);
		}
		//free memory
		for (int p = 0; p < 3; p++) {
			delete[] changedCoords[p];
		}
		delete[] changedCoords;
		delete[] fullDisplacement;
	}
	if (setMap["print_super"] == 1 or setMap["print_super"] == 2) {
		fileHandler -> closeFile(file);
	}
	if (setMap["compareConfs"] == 1) {
		fileHandler -> writeDists(distFile, confComp -> getMinimalDiff());
		fileHandler -> closeFile(distFile);
		delete confComp;
	}
}

/**
 * Adds a displacement to the starting coordinates and returns them as floats.
 *
 * @param startCoords      startCoords[0], [1], [2] hold the x, y and z coordinates of every
 *                         atom. y and z must have at least as many entries as x.
 * @param displacementStep per-atom interleaved displacement (x0, y0, z0, x1, ...). It must
 *                         hold at least 3 * startCoords[0].size() values.
 * @return a new[]-allocated float[3][N] array (N = number of x coordinates). The caller
 *         releases each of the three rows and then the outer array with delete[].
 */
float** TrajectoryGeneratorLapack::calculateNewCoords(vector<vector<double> > startCoords, double* displacementStep) {
	// Bind by reference: the parameter is already a copy, so there is no need to
	// deep-copy each axis again (this runs once per written frame).
	const vector<double>& x = startCoords.at(0);
	const vector<double>& y = startCoords.at(1);
	const vector<double>& z = startCoords.at(2);
	const size_t atomCount = x.size();

	float** xyz = new float*[3];
	for (int axis = 0; axis < 3; axis++) {
		xyz[axis] = new float[atomCount]; /*[3][N]*/
	}

	for (size_t i = 0; i < atomCount; i++) {
		xyz[0][i] = static_cast<float>(x.at(i) + displacementStep[i * 3]);
		xyz[1][i] = static_cast<float>(y.at(i) + displacementStep[i * 3 + 1]);
		xyz[2][i] = static_cast<float>(z.at(i) + displacementStep[i * 3 + 2]);
	}
	return xyz;
}

/**
 * Returns a new[]-allocated copy of eigenvec (linExpansions.size() + 6 values) scaled by
 * linExpansions[trajIndex] * cos(time). eigenvec itself is only read. The caller owns
 * the result and frees it with delete[].
 */
double* TrajectoryGeneratorLapack::calculateSingleDisplacementStep(vector<double> linExpansions, double* eigenvec, int trajIndex, double time) {
	const int length = linExpansions.size() + 6;
	double* scaled = new double[length];
	memcpy(scaled, eigenvec, length * sizeof(double));

	// The trailing 'false' makes the LaGenMatDouble wrap 'scaled' rather than copy it,
	// so the BLAS scaling below writes directly into the returned buffer.
	LaGenMatDouble column(scaled, length, 1, false);
	const double amplitude = linExpansions.at(trajIndex) * cos(time);
	Blas_Scale(amplitude, column);
	return scaled;
}

/**
 * Returns a new[]-allocated copy of eigenvec (linExpansions.size() + 6 values) scaled by
 * linExpansions[trajIndex] * cos(sqrt(eigenvals[trajIndex]) * time + phases[trajIndex + 1]).
 * The phase map is keyed by the 1-based mode id, hence the "+ 1". eigenvec itself is only
 * read. The caller owns the result and frees it with delete[].
 */
double* TrajectoryGeneratorLapack::calculateSuperDisplacementStep(vector<double> linExpansions, double* eigenvec, vector<double> eigenvals, int trajIndex, double time, PhaseMap phases) {
	const int length = linExpansions.size() + 6;
	double* scaled = new double[length];
	memcpy(scaled, eigenvec, length * sizeof(double));
	// The trailing 'false' makes the LaGenMatDouble wrap 'scaled' rather than copy it.
	LaGenMatDouble column(scaled, length, 1, false);

	const double omega = sqrt(eigenvals.at(trajIndex));
	const double amplitude = linExpansions.at(trajIndex) * cos(omega * time + phases[trajIndex + 1]);
	Blas_Scale(amplitude, column);
	return scaled;
}

/**
 * Computes one amplitude per eigenvalue:
 * alpha_k = sqrt(2 * uniGasConst * temp) / sqrt(eigenvalues[k]).
 * The result has the same length and order as eigenvalues.
 */
vector<double> TrajectoryGeneratorLapack::calculateLinearExpansions(vector<double> eigenvalues) {
	SettingsMap config = settings -> getSettings();
	const double thermalTerm = sqrt(2 * config["uniGasConst"] * config["temp"]);

	vector<double> amplitudes;
	amplitudes.reserve(eigenvalues.size());
	for (vector<double>::const_iterator lambda = eigenvalues.begin(); lambda != eigenvalues.end(); ++lambda) {
		amplitudes.push_back(thermalTerm / sqrt(*lambda));
	}
	return amplitudes;
}

double* TrajectoryGeneratorLapack::prepareHessianMatrix(double* modifiedHessianMatrix, int numberOfEVals, const char* protName) {

	MassMap allMasses = settings -> getMasses();
	vector<double> masses = allMasses[protName];

	LaGenMatDouble modHessMatr = LaGenMatDouble(modifiedHessianMatrix, numberOfEVals + 6, numberOfEVals, false);
	//this step must be done because the hessian matrix has
	for (int i = 0; i < (numberOfEVals + 6) / 3; i++) {
		for (int j = 0; j < numberOfEVals; j++) {
			for (int k = 0; k < 3; k++) {
				modHessMatr(i * 3 + k, j) = modHessMatr(i * 3 + k, j) / sqrt(masses.at(i));
			}
		}
	}
	return modifiedHessianMatrix;
}

/**
 * Ensures that every mode id in superTrajs has a phase. Ids already present in phaseMap
 * keep their value. Every other id receives rand() / RAND_MAX scaled to [0, 2*pi].
 * Ids are processed in list order, so rand() is called once per missing id, in that order.
 *
 * @param phaseMap   phases keyed by mode id (received by value; the caller's map is untouched)
 * @param superTrajs mode ids that need a phase
 * @return the completed copy of phaseMap
 */
PhaseMap TrajectoryGeneratorLapack::preparePhasesMap(PhaseMap phaseMap, vector<int> superTrajs) {
	const double lowest = 0.0;
	const double highest = 2.0*M_PI;
	for (vector<int>::const_iterator mode = superTrajs.begin(); mode != superTrajs.end(); ++mode) {
		// A phase supplied by the settings takes precedence over a random one.
		if (phaseMap.find(*mode) != phaseMap.end()) {
			continue;
		}
		const double unit = static_cast<double>(rand()) / static_cast<double>(RAND_MAX);
		phaseMap[*mode] = lowest + unit * (highest - lowest);
	}
	return phaseMap;
}
