//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Resample/Particle/IReParticle.cpp
//! @brief     Implements interface class IReParticle.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Resample/Particle/IReParticle.h"
#include "Base/Spin/SpinMatrix.h"
#include "Base/Util/Assert.h"
#include "Base/Vector/WavevectorInfo.h"
#include "Resample/Element/DiffuseElement.h"
#include "Resample/Flux/MatrixFlux.h"
#include "Resample/Flux/ScalarFlux.h"
#include "Sample/Material/Admixtures.h"
#include <memory>

IReParticle::IReParticle() = default;

IReParticle::IReParticle(const std::optional<size_t>& i_layer)
    : m_i_layer(i_layer)
{
}

IReParticle::~IReParticle() = default;

//! Returns the admixture described by the stored fraction and the stored material.
//!
//! The material must have been set beforehand through setAdmixedMaterial();
//! otherwise there is nothing to dereference and the ASSERT fails.
OneAdmixture IReParticle::admixed() const
{
    // Guard against dereferencing an unset material pointer.
    ASSERT(m_admixed_material != nullptr);
    return {m_admixed_fraction, *m_admixed_material};
}

void IReParticle::setAdmixedFraction(double fraction)
{
    m_admixed_fraction = fraction;
}

//! Stores a private copy of the given material, replacing any previously stored one.
void IReParticle::setAdmixedMaterial(const Material& material)
{
    auto material_copy = std::make_unique<Material>(material);
    m_admixed_material.reset(material_copy.release());
}

//! Returns the absolute value of the form factor evaluated for the
//! wavevector info produced by WavevectorInfo::makeZeroQ().
double IReParticle::volume() const
{
    return std::abs(theFF(WavevectorInfo::makeZeroQ()));
}

//! Returns the coherent scalar form factor of this particle for the given diffuse element.
//!
//! If no layer index is set, theFF() is evaluated once for the element's own wavevectors.
//! Otherwise the result is the sum of four terms, each being theFF() for one combination of
//! in-layer incoming/outgoing wavevectors, weighted by the matching T/R coefficients of the
//! layer's in- and out-flux.
//!
//! The in- and out-flux of the layer must be ScalarFlux instances; this is checked by ASSERT.
complex_t IReParticle::coherentFF(const DiffuseElement& ele) const
{
    const WavevectorInfo& wavevectors = ele.wavevectorInfo();

    if (!i_layer().has_value())
        // no slicing, pure Born approximation
        return theFF(wavevectors);

    const size_t iLayer = i_layer().value();

    const auto* inFlux = dynamic_cast<const ScalarFlux*>(ele.fluxIn(iLayer));
    const auto* outFlux = dynamic_cast<const ScalarFlux*>(ele.fluxOut(iLayer));
    // dynamic_cast yields nullptr if the flux is not a ScalarFlux;
    // stop here rather than dereference a null pointer below.
    ASSERT(inFlux);
    ASSERT(outFlux);

    // Retrieve the two different incoming wavevectors in the layer
    const C3& ki = wavevectors.getKi();
    const complex_t kiz = inFlux->getScalarKz();
    const C3 k_i_T{ki.x(), ki.y(), -kiz};
    const C3 k_i_R{ki.x(), ki.y(), +kiz};

    // Retrieve the two different outgoing wavevector bins in the layer
    const C3& kf = wavevectors.getKf();
    const complex_t kfz = outFlux->getScalarKz();
    const C3 k_f_T{kf.x(), kf.y(), +kfz};
    const C3 k_f_R{kf.x(), kf.y(), -kfz};

    // Construct the four different scattering contributions wavevector infos
    const double wavelength = wavevectors.vacuumLambda();
    const WavevectorInfo q_TT(k_i_T, k_f_T, wavelength);
    const WavevectorInfo q_RT(k_i_R, k_f_T, wavelength);
    const WavevectorInfo q_TR(k_i_T, k_f_R, wavelength);
    const WavevectorInfo q_RR(k_i_R, k_f_R, wavelength);

    // Get the four R,T coefficients
    const complex_t T_in = inFlux->getScalarT();
    const complex_t R_in = inFlux->getScalarR();
    const complex_t T_out = outFlux->getScalarT();
    const complex_t R_out = outFlux->getScalarR();

    // The four different scattering contributions;
    // S stands for scattering off the particle,
    // R for reflection off the layer interface.

    // Note: multiplication does not short-circuit, so theFF() is evaluated
    // for every term, even when its prefactor happens to be zero.

    const complex_t term_S = T_in * T_out * theFF(q_TT);
    const complex_t term_RS = R_in * T_out * theFF(q_RT);
    const complex_t term_SR = T_in * R_out * theFF(q_TR);
    const complex_t term_RSR = R_in * R_out * theFF(q_RR);

    return term_S + term_RS + term_SR + term_RSR;
}

//! Returns the coherent polarized form factor of this particle for the given diffuse element.
//!
//! If no layer index is set, thePolFF() is evaluated once for the element's own wavevectors
//! and its entries are rearranged into the returned matrix.
//! Otherwise the result is the sum of 16 matrix terms: one for each combination of
//! in-eigenmode (1/2), out-eigenmode (1/2), and transmitted/reflected amplitude on the
//! incoming and on the outgoing side.
//!
//! The in- and out-flux of the layer must be MatrixFlux instances; this is checked by ASSERT.
SpinMatrix IReParticle::coherentPolFF(const DiffuseElement& ele) const
{
    const WavevectorInfo& wavevectors = ele.wavevectorInfo();

    if (!i_layer().has_value()) {
        // no slicing, pure Born approximation
        SpinMatrix o = thePolFF(wavevectors);
        return {-o.c, +o.a, -o.d, +o.b};
    }
    const size_t iLayer = i_layer().value();

    const auto* inFlux = dynamic_cast<const MatrixFlux*>(ele.fluxIn(iLayer));
    const auto* outFlux = dynamic_cast<const MatrixFlux*>(ele.fluxOut(iLayer));
    // dynamic_cast yields nullptr if the flux is not a MatrixFlux;
    // stop here rather than dereference a null pointer below.
    ASSERT(inFlux);
    ASSERT(outFlux);

    // the required wavevectors inside the layer for
    // different eigenmodes and in- and outgoing wavevector;
    const complex_t kix = wavevectors.getKi().x();
    const complex_t kiy = wavevectors.getKi().y();
    const Spinor& kiz = inFlux->getKz();
    const C3 ki_1R{kix, kiy, +kiz.u};
    const C3 ki_1T{kix, kiy, -kiz.u};
    const C3 ki_2R{kix, kiy, +kiz.v};
    const C3 ki_2T{kix, kiy, -kiz.v};

    const complex_t kfx = wavevectors.getKf().x();
    const complex_t kfy = wavevectors.getKf().y();
    const Spinor& kfz = outFlux->getKz();
    const C3 kf_1R{kfx, kfy, -kfz.u};
    const C3 kf_1T{kfx, kfy, +kfz.u};
    const C3 kf_2R{kfx, kfy, -kfz.v};
    const C3 kf_2T{kfx, kfy, +kfz.v};

    // now each of the 16 matrix terms of the polarized DWBA is calculated:
    // NOTE: when the underlying reflection/transmission coefficients are
    // scalar, the eigenmodes have identical eigenvalues and spin polarization
    // is used as a basis; in this case however the matrices get mixed:
    //     real m_M11 = calculated m_M12
    //     real m_M12 = calculated m_M11
    //     real m_M21 = calculated m_M22
    //     real m_M22 = calculated m_M21
    // since both eigenvalues are identical, this does not influence the result.

    const double wavelength = wavevectors.vacuumLambda();

    // Computes one of the 16 terms: evaluates thePolFF() for the wavevector pair (ki, kf) and
    // contracts it with the given in- and out-amplitudes. The returned matrix contains the
    // four polarization conditions (p->p, p->m, m->p, m->m).
    // NOTE(review): each amplitude accessor is now called once per term instead of twice;
    // this assumes the MatrixFlux accessors have no side effects - confirm.
    const auto polTerm = [this, wavelength](const C3& ki, const C3& kf, const auto& in_plus,
                                            const auto& in_min, const auto& out_plus,
                                            const auto& out_min) -> SpinMatrix {
        SpinMatrix ff_BA = thePolFF({ki, kf, wavelength});
        return SpinMatrix(-DotProduct(out_min, ff_BA * in_plus),
                          +DotProduct(out_plus, ff_BA * in_plus),
                          -DotProduct(out_min, ff_BA * in_min),
                          +DotProduct(out_plus, ff_BA * in_min));
    };

    // In the names below, the two digits indicate a scattering from the 1/2 eigenstate into
    // the 1/2 eigenstate, while the capital letters indicate a reflection (R)
    // before and/or after the scattering event (S); first index is in-state,
    // second is out-state.

    // eigenmode 1 -> eigenmode 1: direct scattering
    SpinMatrix M11_S = polTerm(ki_1T, kf_1T, inFlux->T1plus(), inFlux->T1min(),
                               outFlux->T1plus(), outFlux->T1min());
    // eigenmode 1 -> eigenmode 1: reflection and then scattering
    SpinMatrix M11_RS = polTerm(ki_1R, kf_1T, inFlux->R1plus(), inFlux->R1min(),
                                outFlux->T1plus(), outFlux->T1min());
    // eigenmode 1 -> eigenmode 1: scattering and then reflection
    SpinMatrix M11_SR = polTerm(ki_1T, kf_1R, inFlux->T1plus(), inFlux->T1min(),
                                outFlux->R1plus(), outFlux->R1min());
    // eigenmode 1 -> eigenmode 1: reflection, scattering and again reflection
    SpinMatrix M11_RSR = polTerm(ki_1R, kf_1R, inFlux->R1plus(), inFlux->R1min(),
                                 outFlux->R1plus(), outFlux->R1min());

    // eigenmode 1 -> eigenmode 2: direct scattering
    SpinMatrix M12_S = polTerm(ki_1T, kf_2T, inFlux->T1plus(), inFlux->T1min(),
                               outFlux->T2plus(), outFlux->T2min());
    // eigenmode 1 -> eigenmode 2: reflection and then scattering
    SpinMatrix M12_RS = polTerm(ki_1R, kf_2T, inFlux->R1plus(), inFlux->R1min(),
                                outFlux->T2plus(), outFlux->T2min());
    // eigenmode 1 -> eigenmode 2: scattering and then reflection
    SpinMatrix M12_SR = polTerm(ki_1T, kf_2R, inFlux->T1plus(), inFlux->T1min(),
                                outFlux->R2plus(), outFlux->R2min());
    // eigenmode 1 -> eigenmode 2: reflection, scattering and again reflection
    SpinMatrix M12_RSR = polTerm(ki_1R, kf_2R, inFlux->R1plus(), inFlux->R1min(),
                                 outFlux->R2plus(), outFlux->R2min());

    // eigenmode 2 -> eigenmode 1: direct scattering
    SpinMatrix M21_S = polTerm(ki_2T, kf_1T, inFlux->T2plus(), inFlux->T2min(),
                               outFlux->T1plus(), outFlux->T1min());
    // eigenmode 2 -> eigenmode 1: reflection and then scattering
    SpinMatrix M21_RS = polTerm(ki_2R, kf_1T, inFlux->R2plus(), inFlux->R2min(),
                                outFlux->T1plus(), outFlux->T1min());
    // eigenmode 2 -> eigenmode 1: scattering and then reflection
    SpinMatrix M21_SR = polTerm(ki_2T, kf_1R, inFlux->T2plus(), inFlux->T2min(),
                                outFlux->R1plus(), outFlux->R1min());
    // eigenmode 2 -> eigenmode 1: reflection, scattering and again reflection
    SpinMatrix M21_RSR = polTerm(ki_2R, kf_1R, inFlux->R2plus(), inFlux->R2min(),
                                 outFlux->R1plus(), outFlux->R1min());

    // eigenmode 2 -> eigenmode 2: direct scattering
    SpinMatrix M22_S = polTerm(ki_2T, kf_2T, inFlux->T2plus(), inFlux->T2min(),
                               outFlux->T2plus(), outFlux->T2min());
    // eigenmode 2 -> eigenmode 2: reflection and then scattering
    SpinMatrix M22_RS = polTerm(ki_2R, kf_2T, inFlux->R2plus(), inFlux->R2min(),
                                outFlux->T2plus(), outFlux->T2min());
    // eigenmode 2 -> eigenmode 2: scattering and then reflection
    SpinMatrix M22_SR = polTerm(ki_2T, kf_2R, inFlux->T2plus(), inFlux->T2min(),
                                outFlux->R2plus(), outFlux->R2min());
    // eigenmode 2 -> eigenmode 2: reflection, scattering and again reflection
    SpinMatrix M22_RSR = polTerm(ki_2R, kf_2R, inFlux->R2plus(), inFlux->R2min(),
                                 outFlux->R2plus(), outFlux->R2min());

    // Summation order is kept fixed so that floating-point rounding is reproducible.
    return M11_S + M11_RS + M11_SR + M11_RSR + M12_S + M12_RS + M12_SR + M12_RSR + M21_S + M21_RS
           + M21_SR + M21_RSR + M22_S + M22_RS + M22_SR + M22_RSR;
}
