//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Export/SimulationToPython.cpp
//! @brief     Implements class SimulationToPython.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Export/SimulationToPython.h"
#include "Base/Axis/Scale.h"
#include "Base/Py/PyFmt.h"
#include "Base/Util/Assert.h"
#include "Device/Beam/Beam.h"
#include "Device/Beam/FootprintGauss.h"
#include "Device/Beam/FootprintSquare.h"
#include "Device/Detector/OffspecDetector.h"
#include "Device/Detector/RectangularDetector.h"
#include "Device/Detector/SphericalDetector.h"
#include "Device/Mask/DetectorMask.h"
#include "Device/Resolution/ConvolutionDetectorResolution.h"
#include "Device/Resolution/ResolutionFunction2DGaussian.h"
#include "Param/Distrib/Distributions.h"
#include "Param/Node/NodeUtil.h"
#include "Resample/Options/SimulationOptions.h"
#include "Sim/Background/ConstantBackground.h"
#include "Sim/Background/PoissonBackground.h"
#include "Sim/Export/PyFmt2.h"
#include "Sim/Export/SampleToPython.h"
#include "Sim/Scan/AlphaScan.h"
#include "Sim/Scan/QzScan.h"
#include "Sim/Simulation/includeSimulations.h"
#include <iomanip>

using Py::Fmt::indent;

namespace {

//! Returns the formatter used to print a coordinate of the given detector as a Python
//! literal: Py::Fmt::printDouble for millimetre-based detectors, Py::Fmt::printDegrees
//! for radian-based ones. Any other coordinate system triggers an ASSERT.
std::function<std::string(double)> printFunc(const IDetector& detector)
{
    const auto coords = detector.defaultCoords();
    if (coords == Coords::RADIANS)
        return Py::Fmt::printDegrees;
    if (coords == Coords::MM)
        return Py::Fmt::printDouble;
    ASSERT(false); // unknown detector units
}

//! Returns true if both detector axes have the same number of bins and spans that agree
//! to a relative tolerance of 1e-12 (measured against the sum of the two spans).
bool isQuadraticDetector(const IDetector& det)
{
    const auto& firstAxis = det.axis(0);
    const auto& secondAxis = det.axis(1);
    if (firstAxis.size() != secondAxis.size())
        return false;

    const auto firstSpan = firstAxis.span();
    const auto secondSpan = secondAxis.span();
    const auto tolerance = 1e-12 * (firstSpan + secondSpan);
    // Written as a negated '>' so that a NaN difference yields true, as before.
    return !(std::abs(firstSpan - secondSpan) > tolerance);
}

//! Returns true if `direction` equals (0, -1, 0) up to an absolute tolerance of
//! 5e-16 on the x and z components and 1e-15 on the y component.
bool isDefaultDirection(const R3 direction)
{
    const bool xVanishes = fabs(direction.x()) < 5e-16;
    const bool yIsMinusOne = fabs(direction.y() + 1) < 1e-15;
    const bool zVanishes = fabs(direction.z()) < 5e-16;
    return xVanishes && yIsMinusOne && zVanishes;
}

//! Returns one indented Python line that constructs `foot` (by its class name and
//! width ratio) and assigns it to the variable "footprint".
std::string defineFootprint(const IFootprint& foot)
{
    return indent() + "footprint = ba." + foot.className() + "("
           + Py::Fmt::printDouble(foot.widthRatio()) + ")\n";
}

std::string defineAlphaScan(const AlphaScan& scan)
{
    std::ostringstream result;
    result << indent() << "axis = " << Py::Fmt2::printAxis(scan.coordinateAxis(), "rad") << "\n"
           << indent() << "scan = "
           << "ba.AlphaScan(axis)\n";

    if (const IDistribution1D* d = scan.angleDistribution()) {
        result << indent() << "distribution = " << Py::Fmt2::printDistribution(*d);
        result << indent() << "scan.setAngleDistribution(distribution)\n";
    }
    if (const IDistribution1D* d = scan.wavelengthDistribution()) {
        result << indent() << "distribution = " << Py::Fmt2::printDistribution(*d);
        result << indent() << "scan.setWavelengthDistribution(distribution)\n";
    } else
        result << indent() << "scan.setWavelength(" << Py::Fmt::printDouble(scan.wavelength())
               << ")\n";
    return result.str();
}

//! Returns Python code that builds a QzScan as variable "scan" from its coordinate axis
//! and, if a qz distribution is set, the matching relative or absolute q resolution.
//! Vector-valued absolute resolutions are not exportable and trigger an ASSERT.
std::string defineQzScan(const QzScan& scan)
{
    std::ostringstream result;
    // The resolution widths below are streamed as raw doubles; use the same 12-digit
    // precision as the other exporters in this file, not the 6-digit stream default.
    result << std::setprecision(12);

    // TODO correct unit would be 1/nm
    result << indent() << "axis = " << Py::Fmt2::printAxis(scan.coordinateAxis(), "") << "\n";
    result << indent() << "scan = ba.QzScan(axis)\n";

    if (const IDistribution1D* d = scan.qzDistribution()) {
        result << indent() << "distribution = " << Py::Fmt2::printDistribution(*d);
        const auto& widths = scan.resolution_widths();
        if (scan.resolution_is_relative())
            result << indent() << "scan.setRelativeQResolution(distribution, " << widths.at(0)
                   << ")\n";
        else if (widths.size() == 1)
            result << indent() << "scan.setAbsoluteQResolution(distribution, " << widths.at(0)
                   << ")\n";
        else
            ASSERT(false); // vector resolution export not yet implemented
    }
    return result.str();
}

//! Returns Python code that configures the polarization analyzer on the Python object
//! named `parent` (e.g. "detector" or "scan"). Returns an empty string when the
//! analyzer's Bloch vector is zero.
//!
//! The Bloch vector is written to variable "analyzer_Bloch_vector" and passed to
//! `setAnalyzer` together with the analyzer transmission.
std::string definePolarizationAnalyzer(const PolFilter& analyzer, const std::string& parent)
{
    std::ostringstream result;
    const R3& v = analyzer.BlochVector();
    const double transmission = analyzer.transmission();

    if (v.mag2() > 0.0) {
        const std::string direction_name = "analyzer_Bloch_vector";
        result << indent() << direction_name << " = R3(" << Py::Fmt::printDouble(v.x()) << ", "
               << Py::Fmt::printDouble(v.y()) << ", " << Py::Fmt::printDouble(v.z()) << ")\n";
        // NOTE(review): the efficiency argument is always exported as 1. This only
        // round-trips if the Bloch vector already carries the efficiency - confirm
        // against PolFilter / setAnalyzer.
        result << indent() << parent << ".setAnalyzer(" << direction_name << ", "
               << "1"
               << ", " << Py::Fmt::printDouble(transmission) << ")\n";
    }
    return result.str();
}

std::string defineDetector(const IDetector& detector)
{
    std::ostringstream result;
    result << std::setprecision(12);

    if (const auto* const det = dynamic_cast<const SphericalDetector*>(&detector)) {
        result << indent() << "detector = ba.SphericalDetector(";
        if (isQuadraticDetector(*det)) {
            result << det->axis(0).size() << ", " << Py::Fmt::printDegrees(det->axis(0).span())
                   << ", " << Py::Fmt::printDegrees(det->axis(0).center()) << ", "
                   << Py::Fmt::printDegrees(det->axis(1).center());
        } else {
            result << det->axis(0).size() << ", " << Py::Fmt::printDegrees(det->axis(0).min())
                   << ", " << Py::Fmt::printDegrees(det->axis(0).max()) << ", "
                   << det->axis(1).size() << ", " << Py::Fmt::printDegrees(det->axis(1).min())
                   << ", " << Py::Fmt::printDegrees(det->axis(1).max());
        }
        result << ")\n";
    } else if (const auto* const det = dynamic_cast<const RectangularDetector*>(&detector)) {
        result << "\n";
        result << indent() << "detector = ba.RectangularDetector(" << det->xSize() << ", "
               << Py::Fmt::printDouble(det->width()) << ", " << det->ySize() << ", "
               << Py::Fmt::printDouble(det->height()) << ")\n";
        if (det->getDetectorArrangment() == RectangularDetector::GENERIC) {
            result << indent() << "detector.setDetectorPosition("
                   << Py::Fmt::printKvector(det->getNormalVector()) << ", "
                   << Py::Fmt::printDouble(det->getU0()) << ", "
                   << Py::Fmt::printDouble(det->getV0());
            if (!isDefaultDirection(det->getDirectionVector()))
                result << ", " << Py::Fmt::printKvector(det->getDirectionVector());
            result << ")\n";
        } else if (det->getDetectorArrangment() == RectangularDetector::PERPENDICULAR_TO_SAMPLE) {
            result << indent() << "detector.setPerpendicularToSampleX("
                   << Py::Fmt::printDouble(det->getDistance()) << ", "
                   << Py::Fmt::printDouble(det->getU0()) << ", "
                   << Py::Fmt::printDouble(det->getV0()) << ")\n";
        } else if (det->getDetectorArrangment()
                   == RectangularDetector::PERPENDICULAR_TO_DIRECT_BEAM) {
            result << indent() << "detector.setPerpendicularToDirectBeam("
                   << Py::Fmt::printDouble(det->getDistance()) << ", "
                   << Py::Fmt::printDouble(det->getU0()) << ", "
                   << Py::Fmt::printDouble(det->getV0()) << ")\n";
        } else if (det->getDetectorArrangment()
                   == RectangularDetector::PERPENDICULAR_TO_REFLECTED_BEAM) {
            result << indent() << "detector.setPerpendicularToReflectedBeam("
                   << Py::Fmt::printDouble(det->getDistance()) << ", "
                   << Py::Fmt::printDouble(det->getU0()) << ", "
                   << Py::Fmt::printDouble(det->getV0()) << ")\n";
        } else
            ASSERT(false); // unknown alignment
    } else
        ASSERT(false); // unknown detector
    if (detector.hasExplicitRegionOfInterest()) {
        const auto xBounds = detector.regionOfInterestBounds(0);
        const auto yBounds = detector.regionOfInterestBounds(1);
        result << indent() << "detector.setRegionOfInterest(" << printFunc(detector)(xBounds.first)
               << ", " << printFunc(detector)(yBounds.first) << ", "
               << printFunc(detector)(xBounds.second) << ", " << printFunc(detector)(yBounds.second)
               << ")\n";
    }
    result << definePolarizationAnalyzer(detector.analyzer(), "detector");

    if (const IDetectorResolution* resfunc = detector.detectorResolution()) {
        if (const auto* convfunc = dynamic_cast<const ConvolutionDetectorResolution*>(resfunc)) {
            if (const auto* resfunc = dynamic_cast<const ResolutionFunction2DGaussian*>(
                    convfunc->getResolutionFunction2D())) {
                result << indent() << "detector.setResolutionFunction(";
                result << "ba.ResolutionFunction2DGaussian(";
                result << printFunc(detector)(resfunc->sigmaX()) << ", ";
                result << printFunc(detector)(resfunc->sigmaY()) << "))\n";
            } else
                ASSERT(false); // unknown detector resolution function
        } else
            ASSERT(false); // not a ConvolutionDetectorResolution function
    }

    return result.str();
}

//! Returns Python code that stores the beam's polarization vector in variable
//! "beam_polMatrices" and applies it with `beam.setPolarization`. Returns an empty
//! string when the polarization vector has zero magnitude.
std::string defineBeamPolarization(const Beam& beam)
{
    const auto polarization = beam.polVector();
    if (!(polarization.mag() > 0.0))
        return "";

    const std::string varName = "beam_polMatrices";
    std::ostringstream out;
    out << indent() << varName << " = R3(" << Py::Fmt::printDouble(polarization.x()) << ", "
        << Py::Fmt::printDouble(polarization.y()) << ", "
        << Py::Fmt::printDouble(polarization.z()) << ")\n"
        << indent() << "beam.setPolarization(" << varName << ")\n";
    return out.str();
}

std::string defineGISASBeam(const ScatteringSimulation& simulation)
{
    std::ostringstream result;
    const Beam& beam = simulation.beam();

    if (beam.intensity() == 1) {
        result << indent() << "beam = ba.Beam(1, ";
    } else {
        result << indent() << "beam = ba.Beam(" << Py::Fmt::printDouble(beam.intensity()) << ", ";
    }
    result << Py::Fmt::printNm(beam.wavelength()) << ", " << Py::Fmt::printDegrees(beam.alpha_i());
    if (beam.phi_i() != 0)
        result << ", " << Py::Fmt::printDegrees(beam.phi_i());
    result << ")\n";

    if (const IFootprint* fp = beam.footprint()) {
        result << defineFootprint(*fp);
        result << indent() << "beam.setFootprint(footprint)\n";
    }
    result << defineBeamPolarization(beam);

    return result.str();
}

std::string defineBeamScan(const IBeamScan& scan)
{
    std::ostringstream result;
    if (const auto* s = dynamic_cast<const AlphaScan*>(&scan))
        result << defineAlphaScan(*s);
    else if (const auto* s = dynamic_cast<const QzScan*>(&scan))
        result << defineQzScan(*s);
    else
        ASSERT(false);
    if (scan.intensity() != 1)
        result << indent() << "scan.setIntensity(" << scan.intensity() << ")\n";
    if (const IFootprint* fp = scan.footprint()) {
        result << defineFootprint(*fp);
        result << indent() << "scan.setFootprint(footprint)\n";
    }
    const PolFilter* analyzer = scan.analyzer();
    if (analyzer)
        result << definePolarizationAnalyzer(*analyzer, "scan");
    return result.str();
}

std::string defineParameterDistributions(const std::vector<ParameterDistribution>& distributions)
{
    std::ostringstream result;
    if (distributions.empty())
        return "";
    for (size_t i = 0; i < distributions.size(); ++i) {
        const std::string mainParUnits = distributions[i].unitOfParameter();

        const std::string distr = "distr_" + std::to_string(i + 1);
        result << indent() << distr << " = "
               << Py::Fmt2::printDistribution(*distributions[i].getDistribution());

        result << indent() << "simulation.addParameterDistribution(ba."
               << distributions[i].whichParameterAsPyEnum() << ", " << distr << ")\n";
    }
    return result.str();
}

std::string defineMasks(const IDetector& detector)
{
    std::ostringstream result;
    result << std::setprecision(12);

    const DetectorMask* detectorMask = detector.detectorMask();
    if (detectorMask && detectorMask->hasMasks()) {
        result << "\n";
        for (size_t i_mask = 0; i_mask < detectorMask->numberOfMasks(); ++i_mask) {
            const MaskPattern* pat = detectorMask->patternAt(i_mask);
            IShape2D* shape = pat->shape;
            bool mask_value = pat->doMask;
            result << Py::Fmt2::representShape2D(indent(), shape, mask_value, printFunc(detector));
        }
        result << "\n";
    }
    return result.str();
}

//! Returns Python code for the simulation options that are switched on. The thread
//! count is emitted only when it differs from the hardware concurrency. Monte Carlo
//! integration, average materials and specular inclusion are emitted when enabled.
std::string defineSimulationOptions(const SimulationOptions& options)
{
    const std::string call = indent() + "simulation.options().";
    std::ostringstream out;
    out << std::setprecision(12);

    if (options.getHardwareConcurrency() != options.getNumberOfThreads())
        out << call << "setNumberOfThreads(" << options.getNumberOfThreads() << ")\n";
    if (options.isIntegrate())
        out << call << "setMonteCarloIntegration(True, " << options.getMcPoints() << ")\n";
    if (options.useAvgMaterials())
        out << call << "setUseAvgMaterials(True)\n";
    if (options.includeSpecular())
        out << call << "setIncludeSpecular(True)\n";
    return out.str();
}

//! Returns Python code that sets the simulation background. It covers a ConstantBackground
//! with a positive value and a PoissonBackground; every other case returns an empty string.
std::string defineBackground(const ISimulation& simulation)
{
    const auto* bg = simulation.background();

    std::string constructor;
    if (const auto* constant_bg = dynamic_cast<const ConstantBackground*>(bg)) {
        if (!(constant_bg->backgroundValue() > 0.0))
            return "";
        constructor = "ba.ConstantBackground("
                      + Py::Fmt::printScientificDouble(constant_bg->backgroundValue()) + ")";
    } else if (dynamic_cast<const PoissonBackground*>(bg)) {
        constructor = "ba.PoissonBackground()";
    } else {
        return "";
    }

    return indent() + "background = " + constructor + "\n" + indent()
           + "simulation.setBackground(background)\n";
}

//! Returns the body of get_simulation() for a GISAS (ScatteringSimulation): beam,
//! detector, simulation construction, parameter distributions, masks, options and
//! background, in that order.
std::string defineScatteringSimulation(const ScatteringSimulation& simulation)
{
    std::string code = "\n" + indent() + "# Define GISAS simulation:\n";
    code += defineGISASBeam(simulation);
    code += defineDetector(simulation.detector());
    code += indent() + "simulation = ba.ScatteringSimulation(beam, sample, detector)\n";
    code += defineParameterDistributions(simulation.paramDistributions());
    code += defineMasks(simulation.detector());
    code += defineSimulationOptions(simulation.options());
    code += defineBackground(simulation);
    return code;
}

//! Returns the body of get_simulation() for an OffspecSimulation: scan, OffspecDetector
//! (bin counts and min/max of both axes in degrees), detector analyzer, simulation
//! construction, parameter distributions, options and background.
std::string defineOffspecSimulation(const OffspecSimulation& simulation)
{
    std::string code = "\n" + indent() + "# Define off-specular simulation:\n";
    code += defineBeamScan(*simulation.scan());

    const OffspecDetector& detector = simulation.detector();
    const auto& firstAxis = detector.axis(0);
    const auto& secondAxis = detector.axis(1);

    std::ostringstream detectorCall;
    detectorCall << std::setprecision(12) << indent() << "detector = ba.OffspecDetector("
                 << firstAxis.size() << ", " << Py::Fmt::printDegrees(firstAxis.min()) << ", "
                 << Py::Fmt::printDegrees(firstAxis.max()) << ", " << secondAxis.size() << ", "
                 << Py::Fmt::printDegrees(secondAxis.min()) << ", "
                 << Py::Fmt::printDegrees(secondAxis.max()) << ")\n";
    code += detectorCall.str();
    code += definePolarizationAnalyzer(detector.analyzer(), "detector");

    code += indent() + "simulation = ba.OffspecSimulation(scan, sample, detector)\n";
    code += defineParameterDistributions(simulation.paramDistributions());
    code += defineSimulationOptions(simulation.options());
    code += defineBackground(simulation);
    return code;
}

//! Returns the body of get_simulation() for a SpecularSimulation: scan, simulation
//! construction, parameter distributions, options and background, followed by a
//! blank line. Beam intensity is exported by defineBeamScan via scan.setIntensity.
std::string defineSpecularSimulation(const SpecularSimulation& simulation)
{
    std::string code = "\n" + indent() + "# Define specular scan:\n";
    code += defineBeamScan(*simulation.scan());
    code += indent() + "simulation = ba.SpecularSimulation(scan, sample)\n";
    code += defineParameterDistributions(simulation.paramDistributions());
    code += defineSimulationOptions(simulation.options());
    code += defineBackground(simulation);
    code += "\n";
    return code;
}

std::string defineSimulate(const ISimulation& simulation)
{
    std::ostringstream result;
    result << "def get_simulation(sample):\n";
    if (const auto* s = dynamic_cast<const ScatteringSimulation*>(&simulation))
        result << defineScatteringSimulation(*s);
    else if (const auto* s = dynamic_cast<const OffspecSimulation*>(&simulation))
        result << defineOffspecSimulation(*s);
    else if (const auto* s = dynamic_cast<const SpecularSimulation*>(&simulation))
        result << defineSpecularSimulation(*s);
    else
        ASSERT(false);
    result << "    return simulation\n\n\n";

    return result.str();
}

//! Returns a complete Python module, without a `__main__` section. It holds the
//! bornagain import, the imported-symbol lines derived from the generated code, the
//! sample definition and get_simulation(). The simulation must have a sample (ASSERTed).
std::string simulationCode(const ISimulation& simulation)
{
    ASSERT(simulation.sample());
    const std::string body =
        SampleToPython().sampleCode(*simulation.sample()) + defineSimulate(simulation);

    std::string module = "import bornagain as ba\n";
    module += Py::Fmt::printImportedSymbols(body);
    module += "\n\n";
    module += body;
    return module;
}

} // namespace

//  ************************************************************************************************
//  class SimulationToPython
//  ************************************************************************************************

//! Returns a runnable Python script that builds the sample and simulation, runs it and
//! plots the result with bornagain's ba_plot module.
std::string SimulationToPython::simulationPlotCode(const ISimulation& simulation)
{
    std::string script = simulationCode(simulation);
    script += "if __name__ == '__main__':\n";
    script += "    from bornagain import ba_plot as bp\n";
    script += "    bp.parse_args()\n";
    script += "    sample = get_sample()\n";
    script += "    simulation = get_simulation(sample)\n";
    script += "    result = simulation.simulate()\n";
    script += "    bp.plot_simulation_result(result)\n";
    return script;
}

//! Returns a runnable Python script that builds the sample and simulation, runs it and
//! writes the result to `fname` via ba.writeDatafield.
//!
//! `fname` is embedded in a double-quoted Python string literal. Backslashes, double
//! quotes and newlines are escaped so that arbitrary file names (in particular Windows
//! paths such as "C:\new\data.int") survive unchanged instead of being read as Python
//! escape sequences or ending the literal early.
std::string SimulationToPython::simulationSaveCode(const ISimulation& simulation,
                                                   const std::string& fname)
{
    std::string escaped;
    escaped.reserve(fname.size());
    for (const char c : fname) {
        switch (c) {
        case '\\':
            escaped += "\\\\";
            break;
        case '"':
            escaped += "\\\"";
            break;
        case '\n':
            escaped += "\\n";
            break;
        default:
            escaped += c;
        }
    }

    return simulationCode(simulation)
           + "if __name__ == '__main__':\n"
             "    sample = get_sample()\n"
             "    simulation = get_simulation(sample)\n"
             "    result = simulation.simulate()\n"
             "    ba.writeDatafield(result, \""
           + escaped + "\")\n";
}
