diff --git a/EvtGenBase/EvtGenKine.hh b/EvtGenBase/EvtGenKine.hh
--- a/EvtGenBase/EvtGenKine.hh
+++ b/EvtGenBase/EvtGenKine.hh
@@ -22,6 +22,7 @@
 #define EVTGENKINE_HH
 
 class EvtVector4R;
+class EvtParticle;
 
 class EvtGenKine {
   public:
@@ -30,6 +31,14 @@
 
     static double PhaseSpacePole( double M, double m1, double m2, double m3,
                                   double a, EvtVector4R p4[10] );
+
+    /*
+     * Set the 4-momenta of the three daughters of p from the invariant masses
+     * squared m12Sq and m23Sq. Call p->makeDaughters() and p->generateMassTree()
+     * first. Momenta are built in the parent rest frame, then randomly rotated.
+     */
+    static void ThreeBodyKine( const double m12Sq, const double m23Sq,
+                               EvtParticle* p );
 };
 
 #endif
diff --git a/EvtGenModels/EvtFlatSqDalitz.hh b/EvtGenModels/EvtFlatSqDalitz.hh
--- a/EvtGenModels/EvtFlatSqDalitz.hh
+++ b/EvtGenModels/EvtFlatSqDalitz.hh
@@ -40,6 +40,12 @@
     void initProbMax() override;
 
     void decay( EvtParticle* p ) override;
+
+  private:
+    double m_mPrimeMin{ 0. };
+    double m_mPrimeMax{ 1. };
+    double m_thetaPrimeMin{ 0. };
+    double m_thetaPrimeMax{ 1. };
 };
 
 #endif
diff --git a/src/EvtGenBase/EvtGenKine.cpp b/src/EvtGenBase/EvtGenKine.cpp
--- a/src/EvtGenBase/EvtGenKine.cpp
+++ b/src/EvtGenBase/EvtGenKine.cpp
@@ -25,6 +25,7 @@
 #include "EvtGenBase/EvtRandom.hh"
 #include "EvtGenBase/EvtReport.hh"
 #include "EvtGenBase/EvtVector4R.hh"
+#include "EvtGenBase/EvtParticle.hh"
 
 #include <iostream>
 #include <math.h>
@@ -340,3 +341,60 @@
 
     return 1.0 + a / ( m12sq * m12sq );
 }
+
+/*
+ * Set the 4-momenta of the three daughters of p from the invariant masses
+ * squared m12Sq and m23Sq. Call p->makeDaughters() and p->generateMassTree()
+ * first. Momenta are built in the parent rest frame, then randomly rotated.
+ */
+void EvtGenKine::ThreeBodyKine( const double m12Sq, const double m23Sq,
+                                EvtParticle* p )
+{
+    const double mParent = p->mass();
+    EvtParticle* daug1 = p->getDaug( 0 );
+    EvtParticle* daug2 = p->getDaug( 1 );
+    EvtParticle* daug3 = p->getDaug( 2 );
+    const double mDaug1 = daug1->mass();
+    const double mDaug2 = daug2->mass();
+    const double mDaug3 = daug3->mass();
+    const double mParentSq{ mParent * mParent };
+    const double mDaug1Sq{ mDaug1 * mDaug1 };
+    const double mDaug2Sq{ mDaug2 * mDaug2 };
+    const double mDaug3Sq{ mDaug3 * mDaug3 };
+    const double invMParent{ 1. / mParent };
+    const double En1 = 0.5 * ( mParentSq + mDaug1Sq - m23Sq ) * invMParent;
+    const double En3 = 0.5 * ( mParentSq + mDaug3Sq - m12Sq ) * invMParent;
+    const double En2 = mParent - En1 - En3;
+    // Rounding on the Dalitz boundary can give p^2 < 0 or |cosPhi| > 1 (NaN)
+    const double p1mag = std::sqrt( std::fmax( 0., En1 * En1 - mDaug1Sq ) );
+    const double p2mag = std::sqrt( std::fmax( 0., En2 * En2 - mDaug2Sq ) );
+    double cosPhi = 0.5 * ( mDaug1Sq + mDaug2Sq + 2 * En1 * En2 - m12Sq ) /
+                    ( p1mag * p2mag );
+    if ( cosPhi > 1.0 || cosPhi < -1.0 ) {
+        EvtGenReport( EVTGEN_WARNING, "EvtGenKine" )
+            << "ThreeBodyKine: cosine outside [-1, 1]: " << cosPhi << std::endl;
+        cosPhi = ( cosPhi > 0. ) ? 1.0 : -1.0;
+    }
+    double sinPhi = std::sqrt( 1 - cosPhi * cosPhi );
+    if ( EvtRandom::Flat( 0., 1. ) > 0.5 ) {
+        sinPhi *= -1;
+    }
+    const double p2x = p2mag * cosPhi;
+    const double p2y = p2mag * sinPhi;
+    const double p3x = -p1mag - p2x;
+    const double p3y = -p2y;
+
+    // Construct 4-momenta in the x-y plane and rotate them randomly in space
+    EvtVector4R p1( En1, p1mag, 0., 0. );
+    EvtVector4R p2( En2, p2x, p2y, 0. );
+    EvtVector4R p3( En3, p3x, p3y, 0. );
+    const double euler1 = EvtRandom::Flat( 0., EvtConst::twoPi );
+    const double euler2 = std::acos( EvtRandom::Flat( -1.0, 1.0 ) );
+    const double euler3 = EvtRandom::Flat( 0., EvtConst::twoPi );
+    p1.applyRotateEuler( euler1, euler2, euler3 );
+    p2.applyRotateEuler( euler1, euler2, euler3 );
+    p3.applyRotateEuler( euler1, euler2, euler3 );
+    daug1->init( daug1->getId(), p1 );
+    daug2->init( daug2->getId(), p2 );
+    daug3->init( daug3->getId(), p3 );
+}
diff --git a/src/EvtGenModels/EvtFlatSqDalitz.cpp b/src/EvtGenModels/EvtFlatSqDalitz.cpp
--- a/src/EvtGenModels/EvtFlatSqDalitz.cpp
+++ b/src/EvtGenModels/EvtFlatSqDalitz.cpp
@@ -20,14 +20,13 @@
 
 #include "EvtGenModels/EvtFlatSqDalitz.hh"
 
-#include "EvtGenBase/EvtDiracSpinor.hh"
+#include "EvtGenBase/EvtConst.hh"
 #include "EvtGenBase/EvtGenKine.hh"
 #include "EvtGenBase/EvtPDL.hh"
 #include "EvtGenBase/EvtParticle.hh"
 #include "EvtGenBase/EvtPatches.hh"
+#include "EvtGenBase/EvtRandom.hh"
 #include "EvtGenBase/EvtReport.hh"
-#include "EvtGenBase/EvtTensor4C.hh"
-#include "EvtGenBase/EvtVector4C.hh"
 
 #include <fstream>
 #include <stdio.h>
@@ -51,72 +50,66 @@
 
 void EvtFlatSqDalitz::initProbMax()
 {
-    setProbMax( 1. );
+    noProbMax();    // decay() builds events directly and never calls setProb()
 }
 
 void EvtFlatSqDalitz::init()
 {
-    // check that there are 0 arguments
-    checkNArg( 0 );
-
     //check there are 3 daughters
     checkNDaug( 3 );
+
+    // 0, 2 (m'min, m'max) or 4 (+ theta'min, theta'max) arguments
+    checkNArg( 0, 2, 4 );
+    // without args the defaults cover the full [0, 1] x [0, 1] square plot
+    if ( getNArg() > 0 ) {
+        m_mPrimeMin = getArg( 0 );
+        m_mPrimeMax = getArg( 1 );
+    }
+    if ( getNArg() > 2 ) {
+        m_thetaPrimeMin = getArg( 2 );
+        m_thetaPrimeMax = getArg( 3 );
+    }
 }
 
 void EvtFlatSqDalitz::decay( EvtParticle* p )
 {
-    p->initializePhaseSpace( getNDaug(), getDaugs() );
-
-    double mB = p->mass();
-    double m1 = p->getDaug( 0 )->mass();
-    double m2 = p->getDaug( 1 )->mass();
-    double m3 = p->getDaug( 2 )->mass();
-
-    EvtVector4R p4_1 = p->getDaug( 0 )->getP4();
-    EvtVector4R p4_2 = p->getDaug( 1 )->getP4();
-    EvtVector4R p4_3 = p->getDaug( 2 )->getP4();
-
-    EvtVector4R p4_12 = p4_1 + p4_2;
-    EvtVector4R p4_13 = p4_1 + p4_3;
-    // do not compute p4_23 to avoid breaking p4 conservation ???
-    EvtVector4R p4_23 = p4_2 + p4_3;
-
-    double m12 = p4_12.mass();
-    double m13 = p4_13.mass();
-    double m23 = p4_23.mass();
-
-    double m12norm = 2 * ( ( m12 - ( m1 + m2 ) ) / ( mB - ( m1 + m2 + m3 ) ) ) -
-                     1;
-    double mPrime = acos( m12norm ) / EvtConst::pi;
-    double thPrime = acos( ( m12 * m12 * ( m23 * m23 - m13 * m13 ) -
-                             ( m2 * m2 - m1 * m1 ) * ( mB * mB - m3 * m3 ) ) /
-                           ( sqrt( pow( m12 * m12 + m1 * m1 - m2 * m2, 2 ) -
-                                   4 * m12 * m12 * m1 * m1 ) *
-                             sqrt( pow( -m12 * m12 + mB * mB - m3 * m3, 2 ) -
-                                   4 * m12 * m12 * m3 * m3 ) ) ) /
-                     EvtConst::pi;
-
-    double p3st = sqrt( pow( mB * mB - m3 * m3 - m12 * m12, 2 ) -
-                        pow( 2 * m12 * m3, 2 ) ) /
-                  ( 2 * m12 );
-    double p1st = sqrt( pow( m2 * m2 - m1 * m1 - m12 * m12, 2 ) -
-                        pow( 2 * m12 * m1, 2 ) ) /
-                  ( 2 * m12 );
-
-    double jacobian = 2 * pow( EvtConst::pi, 2 ) * sin( EvtConst::pi * mPrime ) *
-                      sin( EvtConst::pi * thPrime ) * p1st * p3st * m12 *
-                      ( mB - ( m1 + m2 + m3 ) );
-
-    double prob = 1. / jacobian;    //pow(1./(jacobian),2);
-
-    //  std::cout << mB << " " << mPrime << " " << thPrime << " " << prob << std::endl;
-
-    setProb( prob );
-
-    if ( prob < 1 )
-        setProb( prob );
-    else
-        setProb( 1. );
+    p->makeDaughters( getNDaug(), getDaugs() );
+    p->generateMassTree();
+    const double mParent = p->mass();
+    EvtParticle* daug1 = p->getDaug( 0 );
+    EvtParticle* daug2 = p->getDaug( 1 );
+    EvtParticle* daug3 = p->getDaug( 2 );
+    const double mDaug1 = daug1->mass();
+    const double mDaug2 = daug2->mass();
+    const double mDaug3 = daug3->mass();
+    const double mParentSq = mParent * mParent;
+    const double mDaug1Sq = mDaug1 * mDaug1;
+    const double mDaug2Sq = mDaug2 * mDaug2;
+    const double mDaug3Sq = mDaug3 * mDaug3;
+
+    // Generate m' and theta' flat within the configured ranges
+    const double mPrime = EvtRandom::Flat( m_mPrimeMin, m_mPrimeMax );
+    const double thetaPrime = EvtRandom::Flat( m_thetaPrimeMin, m_thetaPrimeMax );
+
+    // Invert m' = acos(2 (m12 - m12min) / (m12max - m12min) - 1) / pi for m12
+    const double m12 = 0.5 * ( std::cos( mPrime * EvtConst::pi ) + 1 ) *
+                           ( mParent - ( mDaug1 + mDaug2 + mDaug3 ) ) +
+                       mDaug1 + mDaug2;
+    const double m12Sq = m12 * m12;
+    // Energies of daughters 1 and 3 in the rest frame of the 1-2 system
+    const double en1 = ( m12Sq - mDaug2Sq + mDaug1Sq ) / ( 2. * m12 );
+    const double en3 = ( mParentSq - m12Sq - mDaug3Sq ) / ( 2. * m12 );
+    // fmax: rounding at the m12 limits can make p^2 slightly negative (NaN)
+    const double p1 = std::sqrt( std::fmax( 0., en1 * en1 - mDaug1Sq ) );
+    const double p3 = std::sqrt( std::fmax( 0., en3 * en3 - mDaug3Sq ) );
+    // pi * theta' is the angle between daughters 1 and 3 in that frame
+    const double m13Sq =
+        mDaug1Sq + mDaug3Sq +
+        2.0 * ( en1 * en3 - p1 * p3 * std::cos( EvtConst::pi * thetaPrime ) );
+    const double m23Sq = mParentSq - m12Sq - m13Sq + mDaug1Sq + mDaug2Sq +
+                         mDaug3Sq;
+    // Turn m12 and m23 into momenta
+    EvtGenKine::ThreeBodyKine( m12Sq, m23Sq, p );
 
     return;
 }
diff --git a/src/EvtGenModels/EvtThreeBodyPhsp.cpp b/src/EvtGenModels/EvtThreeBodyPhsp.cpp
--- a/src/EvtGenModels/EvtThreeBodyPhsp.cpp
+++ b/src/EvtGenModels/EvtThreeBodyPhsp.cpp
@@ -20,14 +20,15 @@
 
 #include "EvtGenModels/EvtThreeBodyPhsp.hh"
 
-#include "EvtGenBase/EvtParticle.hh"
-#include "EvtGenBase/EvtReport.hh"
 #include "EvtGenBase/EvtConst.hh"
+#include "EvtGenBase/EvtGenKine.hh"
+#include "EvtGenBase/EvtParticle.hh"
 #include "EvtGenBase/EvtRandom.hh"
+#include "EvtGenBase/EvtReport.hh"
 
-#include <iostream>
 #include <algorithm>
 #include <cmath>
+#include <iostream>
 
 std::string EvtThreeBodyPhsp::getName()
 {
@@ -67,11 +68,11 @@
     p->makeDaughters( getNDaug(), getDaugs() );
     p->generateMassTree();
     const double mParent = p->mass();
-    EvtParticle* daug1 = p->getDaug(0);
-    EvtParticle* daug2 = p->getDaug(1);
-    EvtParticle* daug3 = p->getDaug(2);
-    const double mDaug1 = daug1->mass(); 
-    const double mDaug2 = daug2->mass(); 
+    EvtParticle* daug1 = p->getDaug( 0 );
+    EvtParticle* daug2 = p->getDaug( 1 );
+    EvtParticle* daug3 = p->getDaug( 2 );
+    const double mDaug1 = daug1->mass();
+    const double mDaug2 = daug2->mass();
     const double mDaug3 = daug3->mass();
 
     const double m12borderMin = mDaug1 + mDaug2;
@@ -115,42 +116,7 @@
     }
 
     // At this moment we have valid invariant masses squared
-    const double En1 = 0.5 * ( mParent * mParent + mDaug1 * mDaug1 - m23Sq ) / mParent;
-    const double En3 = 0.5 * ( mParent * mParent + mDaug3 * mDaug3 - m12Sq ) / mParent;
-    const double En2 = mParent - En1 - En3;
-    const double p1mag = std::sqrt( En1 * En1 - mDaug1 * mDaug1 );
-    const double p2mag = std::sqrt( En2 * En2 - mDaug2 * mDaug2 );
-    double cosPhi = 0.5 * ( mDaug1 * mDaug1 + mDaug2 * mDaug2 +
-                                  2 * En1 * En2 - m12Sq )/p1mag/p2mag;
-    if ( cosPhi > 1.0 ) {
-        EvtGenReport( EVTGEN_WARNING, "EvtThreeBodyPhsp" )
-            << " Model: cosine larger than one: " << cosPhi << std::endl;
-        cosPhi = 1.0;
-    }
-    double sinPhi = std::sqrt( 1 - cosPhi * cosPhi );
-    if ( EvtRandom::Flat( 0., 1. ) > 0.5 ) {
-        sinPhi *= -1;
-    }
-    const double p2x = p2mag * cosPhi;
-    const double p2y = p2mag * sinPhi;
-    const double p3x = -p1mag - p2x;
-    const double p3y = -p2y;
-
-    // Construct 4-momenta and rotate them randomly in space
-    EvtVector4R p1( En1, p1mag, 0., 0. );
-    EvtVector4R p2( En2, p2x, p2y, 0. );
-    EvtVector4R p3( En3, p3x, p3y, 0. );
-    const double euler1 = EvtRandom::Flat( 0., EvtConst::twoPi );
-    const double euler2 = std::acos( EvtRandom::Flat( -1.0, 1.0 ) );
-    const double euler3 = EvtRandom::Flat( 0., EvtConst::twoPi );
-    p1.applyRotateEuler(euler1, euler2, euler3);
-    p2.applyRotateEuler(euler1, euler2, euler3);
-    p3.applyRotateEuler(euler1, euler2, euler3);
-
-    // set momenta for daughters
-    daug1->init( getDaug( 0 ), p1 );
-    daug2->init( getDaug( 1 ), p2 );
-    daug3->init( getDaug( 2 ), p3 );
+    EvtGenKine::ThreeBodyKine( m12Sq, m23Sq, p );
 
     return;
 }
diff --git a/test/do_tests b/test/do_tests
--- a/test/do_tests
+++ b/test/do_tests
@@ -98,3 +98,4 @@
 time ./evtgenlhc_test1 baryonic 1000
 time ./evtgenlhc_test1 phspdecaytimecut 10000
 time ./evtgenlhc_test1 3bodyPhsp 1000000
+time ./evtgenlhc_test1 flatSqDalitz 1000000
diff --git a/test/evtgenlhc_test1.cc b/test/evtgenlhc_test1.cc
--- a/test/evtgenlhc_test1.cc
+++ b/test/evtgenlhc_test1.cc
@@ -157,6 +157,7 @@
                            TH1F* mom = 0 );
 void runBaryonic( int nEvent, EvtGen& myGenerator );
 void run3BPhspRegion( int nEvent, EvtGen& myGenerator );
+void runFlatSqDalitz( int nEvent, EvtGen& myGenerator );
 
 int main( int argc, char* argv[] )
 {
@@ -531,7 +532,13 @@
     if ( !strcmp( argv[1], "3bodyPhsp" ) ) {
         int nevent = atoi( argv[2] );
         EvtRadCorr::setNeverRadCorr();
-        run3BPhspRegion( nevent, myGenerator);
+        run3BPhspRegion( nevent, myGenerator );
+    }
+
+    if ( !strcmp( argv[1], "flatSqDalitz" ) ) {
+        int nevent = atoi( argv[2] );
+        EvtRadCorr::setNeverRadCorr();
+        runFlatSqDalitz( nevent, myGenerator );
     }
 
     //*******************************************************
@@ -5753,3 +5760,62 @@
     file->Close();
     EvtGenReport( EVTGEN_INFO, "EvtGen" ) << "SUCCESS\n";
 }
+
+void runFlatSqDalitz( int nevent, EvtGen& myGenerator )
+{
+    // Fill the square Dalitz plot (m', theta') of Lambda_b0 -> K+ K- Lambda
+    // decays from exampleFiles/flatSqDalitz.dec; it is expected to be flat.
+    TFile* file = new TFile( "flatSqDalitz.root", "RECREATE" );
+    TH2F* sqDalitz = new TH2F( "h4", "Dalitz", 50, 0.0, 1.0, 50, 0.0, 1.0 );
+    char udecay_name[100];
+    strcpy( udecay_name, "exampleFiles/flatSqDalitz.dec" );
+    myGenerator.readUDecay( udecay_name );
+
+    static EvtId lambdaB = EvtPDL::getId( std::string( "Lambda_b0" ) );
+
+    int iEvent = 1;
+    do {
+        EvtVector4R pAtRest( EvtPDL::getMass( lambdaB ), 0.0, 0.0, 0.0 );
+        EvtParticle* parent =
+            EvtParticleFactory::particleFactory( lambdaB, pAtRest );
+        myGenerator.generateDecay( parent );
+
+        EvtParticle* d1 = parent->getDaug( 0 );
+        EvtParticle* d2 = parent->getDaug( 1 );
+        EvtParticle* d3 = parent->getDaug( 2 );
+        const double mPar = parent->mass();
+        const double mD1 = d1->mass();
+        const double mD2 = d2->mass();
+        const double mD3 = d3->mass();
+        const double mParSq = mPar * mPar;
+        const double mD1Sq = mD1 * mD1;
+        const double mD2Sq = mD2 * mD2;
+        const double mD3Sq = mD3 * mD3;
+
+        const double m12 = ( d1->getP4() + d2->getP4() ).mass();
+        const double m13 = ( d1->getP4() + d3->getP4() ).mass();
+        const double m12Sq = m12 * m12;
+        const double m13Sq = m13 * m13;
+
+        // m' = acos(2 (m12 - m12min) / (m12max - m12min) - 1) / pi
+        const double m12Min = mD1 + mD2;
+        const double m12Range = mPar - ( mD1 + mD2 + mD3 );
+        const double mPrime =
+            acos( 2 * ( ( m12 - m12Min ) / m12Range ) - 1 ) / EvtConst::pi;
+
+        // theta' = theta / pi, theta = angle of d1 and d3 in the d1+d2 frame
+        const double en1 = ( m12Sq - mD2Sq + mD1Sq ) / ( 2.0 * m12 );
+        const double en3 = ( mParSq - m12Sq - mD3Sq ) / ( 2.0 * m12 );
+        const double p1 = std::sqrt( en1 * en1 - mD1Sq );
+        const double p3 = std::sqrt( en3 * en3 - mD3Sq );
+        const double cosTheta = ( -m13Sq + mD1Sq + mD3Sq + 2. * en1 * en3 ) /
+                                ( 2. * p1 * p3 );
+        sqDalitz->Fill( mPrime, acos( cosTheta ) / EvtConst::pi );
+
+        parent->deleteTree();
+    } while ( iEvent++ < nevent );
+
+    file->Write();
+    file->Close();
+    EvtGenReport( EVTGEN_INFO, "EvtGen" ) << "SUCCESS\n";
+}
diff --git a/test/exampleFiles/flatSqDalitz.dec b/test/exampleFiles/flatSqDalitz.dec
new file mode 100644
--- /dev/null
+++ b/test/exampleFiles/flatSqDalitz.dec
@@ -0,0 +1,17 @@
+Alias      MyLambda     Lambda0
+Alias      MyantiLambda anti-Lambda0
+ChargeConj MyLambda     MyantiLambda
+#
+Decay Lambda_b0
+  1.000        K+      K-      MyLambda      FLATSQDALITZ;
+Enddecay
+CDecay anti-Lambda_b0
+#
+Decay MyLambda
+  1.000        p+      pi-                PHSP;
+Enddecay
+CDecay MyantiLambda
+#
+End
+
+