From 6b5c7f13d3940e261e5218e58e0c243008eecc4e Mon Sep 17 00:00:00 2001
From: Pietro Incardona <incardon@mpi-cbg.de>
Date: Sun, 9 Jun 2019 22:42:20 +0200
Subject: [PATCH] Fixing getVector

---
 .../vector_dist_operators_tests_util.hpp      |  35 +++++
 .../Vector/vector_dist_operators.hpp          | 120 +++++++++++++++++-
 .../vector_dist_operators_extensions.hpp      |   5 +-
 .../vector_dist_operators_functions.hpp       |  18 +++
 4 files changed, 176 insertions(+), 2 deletions(-)

diff --git a/src/Operators/Vector/tests/vector_dist_operators_tests_util.hpp b/src/Operators/Vector/tests/vector_dist_operators_tests_util.hpp
index a421ba95..ff873680 100644
--- a/src/Operators/Vector/tests/vector_dist_operators_tests_util.hpp
+++ b/src/Operators/Vector/tests/vector_dist_operators_tests_util.hpp
@@ -911,6 +911,13 @@ void check_all_expressions_imp(vector_type & vd,
 	vA = (vC + vB) / (vC + vB);
 	check_values_div_4<float,vtype,A,B,C,impl>(vd);
 
+	if (impl == comp_host)
+	{
+		auto test = vC + vB;
+		auto & v = test.getVector();
+		BOOST_REQUIRE_EQUAL((void *)&v,(void *)&vd);
+	}
+
 	// We try with vectors
 
 	// Various combination of 2 operator
@@ -943,6 +950,13 @@ void check_all_expressions_imp(vector_type & vd,
 	vVA = vVC / vVB;
 	check_values_div<VectorS<3,float>,vtype,VA,VB,VC,impl>(vd,vd);
 
+	if (impl == comp_host)
+	{
+		auto test = vVB / 2.0;
+		auto & v = test.getVector();
+		BOOST_REQUIRE_EQUAL((void *)&v,(void *)&vd);
+	}
+
 	// Variuos combination 3 operator
 
 	vVA = vVB + (vVC + vVB);
@@ -972,6 +986,13 @@ void check_all_expressions_imp(vector_type & vd,
 	vA = (vVC + vVB) * (vVC + vVB);
 	check_values_mul_4<float,vtype,A,VB,VC,impl>(vd);
 
+	if (impl == comp_host)
+	{
+		auto test = (vVC + vVB) * (vVC + vVB);
+		auto & v = test.getVector();
+		BOOST_REQUIRE_EQUAL((void *)&v,(void *)&vd);
+	}
+
 	vVA = vVB / (vVC + vVB);
 	check_values_div_31<VectorS<3,float>,vtype,VA,VB,VC,impl>(vd);
 	vVA = (vVC + vVB) / vVB;
@@ -999,6 +1020,13 @@ void check_all_expressions_imp(vector_type & vd,
 	vVA = -vVB;
 	check_values_point_sub<Point<3,float>,vtype,VA,VB,VC,impl>(vd,p0);
 
+	if (impl == comp_host)
+	{
+		auto test = vPOS + p0_e;
+		auto & v = test.getVector();
+		BOOST_REQUIRE_EQUAL((void *)&v,(void *)&vd);
+	}
+
 	// Just check it compile testing it will test the same code
 	// as the previuous one
 	vVC = exp(vVB);
@@ -1010,6 +1038,13 @@ void check_all_expressions_imp(vector_type & vd,
 	vVA = 2.0 - vPOS;
 	vVA = vPOS - vPOS;
 
+	if (impl == comp_host)
+	{
+		auto test = exp(vVB);
+		auto & v = test.getVector();
+		BOOST_REQUIRE_EQUAL((void *)&v,(void *)&vd);
+	}
+
 	vVA = vPOS * 2.0;
 	vVA = 2.0 * vPOS;
 	vVA = vPOS * vPOS;
diff --git a/src/Operators/Vector/vector_dist_operators.hpp b/src/Operators/Vector/vector_dist_operators.hpp
index c6e4a19e..363feb36 100644
--- a/src/Operators/Vector/vector_dist_operators.hpp
+++ b/src/Operators/Vector/vector_dist_operators.hpp
@@ -128,6 +128,43 @@ class vector_dist_expression_op
 
 };
 
+template<typename v1_type, typename v2_type>
+struct vector_result
+{
+	typedef v1_type type;
+
+	template<typename exp1, typename exp2>
+	static const type & getVector(const exp1 & o1, const exp2 & o2)
+	{
+		return o1.getVector();
+	}
+
+	template<typename exp1>
+	static const type & getVector(const exp1 & o1)
+	{
+		return o1.getVector();
+	}
+};
+
+
+template<typename v2_type>
+struct vector_result<void,v2_type>
+{
+	typedef v2_type type;
+
+	template<typename exp1, typename exp2>
+	static const type & getVector(const exp1 & o1, const exp2 & o2)
+	{
+		return o2.getVector();
+	}
+
+	template<typename exp2>
+	static const type & getVector(exp2 & o2)
+	{
+		return o2.getVector();
+	}
+};
+
 /*! \brief Sum operation
  *
  * \tparam exp1 expression1
@@ -145,13 +182,27 @@ class vector_dist_expression_op<exp1,exp2,VECT_SUM>
 
 public:
 
+	//! indicate if this vector is kernel type
 	typedef typename exp1::is_ker is_ker;
 
+	//! return the vector type on which this expression operate
+	typedef typename vector_result<typename exp1::vtype,typename exp2::vtype>::type vtype;
+
 	//! constructor of the expression to sum two expression
 	inline vector_dist_expression_op(const exp1 & o1, const exp2 & o2)
 	:o1(o1),o2(o2)
 	{}
 
+	/*! \brief Return the underlying vector
+	 *
+	 * \return the vector
+	 *
+	 */
+	const vtype & getVector() const
+	{
+		return vector_result<typename exp1::vtype,typename exp2::vtype>::getVector(o1,o2);
+	}
+
 	/*! \brief This function must be called before value
 	 *
 	 * it initialize the expression if needed
@@ -210,11 +261,24 @@ public:
 
 	typedef typename exp1::is_ker is_ker;
 
+	//! return the vector type on which this expression operate
+	typedef typename vector_result<typename exp1::vtype,typename exp2::vtype>::type vtype;
+
 	//! Costruct a subtraction expression out of two expressions
 	inline vector_dist_expression_op(const exp1 & o1, const exp2 & o2)
 	:o1(o1),o2(o2)
 	{}
 
+	/*! \brief Return the underlying vector
+	 *
+	 * \return the vector
+	 *
+	 */
+	const vtype & getVector()
+	{
+		return vector_result<typename exp1::vtype,typename exp2::vtype>::getVector(o1,o2);
+	}
+
 	/*! \brief This function must be called before value
 	 *
 	 * it initialize the expression if needed
@@ -272,11 +336,24 @@ public:
 
 	typedef typename exp1::is_ker is_ker;
 
+	//! return the vector type on which this expression operate
+	typedef typename vector_result<typename exp1::vtype,typename exp2::vtype>::type vtype;
+
 	//! constructor from two expressions
 	vector_dist_expression_op(const exp1 & o1, const exp2 & o2)
 	:o1(o1),o2(o2)
 	{}
 
+	/*! \brief Return the underlying vector
+	 *
+	 * \return the vector
+	 *
+	 */
+	const vtype & getVector()
+	{
+		return vector_result<typename exp1::vtype,typename exp2::vtype>::getVector(o1,o2);
+	}
+
 	/*! \brief This function must be called before value
 	 *
 	 * it initialize the expression if needed
@@ -333,11 +410,24 @@ public:
 
 	typedef typename exp1::is_ker is_ker;
 
+	//! return the vector type on which this expression operate
+	typedef typename vector_result<typename exp1::vtype,typename exp2::vtype>::type vtype;
+
 	//! constructor from two expressions
 	vector_dist_expression_op(const exp1 & o1, const exp2 & o2)
 	:o1(o1),o2(o2)
 	{}
 
+	/*! \brief Return the underlying vector
+	 *
+	 * \return the vector
+	 *
+	 */
+	const vtype & getVector()
+	{
+		return vector_result<typename exp1::vtype,typename exp2::vtype>::getVector(o1,o2);
+	}
+
 	/*! \brief This function must be called before value
 	 *
 	 * it initialize the expression if needed
@@ -389,11 +479,24 @@ public:
 
 	typedef typename exp1::is_ker is_ker;
 
+	//! return the vector type on which this expression operate
+	typedef typename vector_result<typename exp1::vtype,void>::type vtype;
+
 	//! constructor from an expresssion
 	vector_dist_expression_op(const exp1 & o1)
 	:o1(o1)
 	{}
 
+	/*! \brief Return the underlying vector
+	 *
+	 * \return the vector
+	 *
+	 */
+	const vtype & getVector()
+	{
+		return vector_result<typename exp1::vtype,void>::getVector(o1);
+	}
+
 	//! initialize the expression tree
 	inline void init() const
 	{
@@ -482,6 +585,7 @@ public:
 	//! Property id of the point
 	static const unsigned int prop = prp;
 
+
 	//! constructor for an external vector
 	vector_dist_expression(vector & v)
 	:v(v),vdl(NULL)
@@ -494,6 +598,18 @@ public:
 		{vdl->remove(v.v);}
 	}
 
+	/*! \brief Return the vector on which is acting
+	 *
+	 * It return the vector used in getVExpr, to get this object
+	 *
+	 * \return the vector
+	 *
+	 */
+	__device__ __host__ const vector & getVector() const
+	{
+		return v.v;
+	}
+
 	/*! \brief Return the vector on which is acting
 	 *
 	 * It return the vector used in getVExpr, to get this object
@@ -674,6 +790,8 @@ public:
 
 	typedef std::false_type is_ker;
 
+	typedef void vtype;
+
 	//! constructor from a constant expression
 	inline vector_dist_expression(const double & d)
 	:d(d)
@@ -732,7 +850,7 @@ public:
 	typedef std::false_type is_ker;
 
 	//! type of object the structure return then evaluated
-	typedef float vtype;
+	typedef void vtype;
 
 	//! constrictor from constant value
 	inline vector_dist_expression(const float & d)
diff --git a/src/Operators/Vector/vector_dist_operators_extensions.hpp b/src/Operators/Vector/vector_dist_operators_extensions.hpp
index adf208df..c24bb661 100644
--- a/src/Operators/Vector/vector_dist_operators_extensions.hpp
+++ b/src/Operators/Vector/vector_dist_operators_extensions.hpp
@@ -15,7 +15,8 @@
  * \param v
  *
  */
-template <unsigned int dim, typename T> inline vector_dist_expression<16384,Point<dim,T> > getVExpr(Point<dim,T> & v)
+template <unsigned int dim, typename T>
+inline vector_dist_expression<16384,Point<dim,T> > getVExpr(Point<dim,T> & v)
 {
 	vector_dist_expression<(unsigned int)16384,Point<dim,T>> exp_v(v);
 
@@ -36,6 +37,8 @@ class vector_dist_expression<16384,point>
 
 public:
 
+	typedef void vtype;
+
 	//! vector expression from a constant point
 	vector_dist_expression(point p)
 	:p(p)
diff --git a/src/Operators/Vector/vector_dist_operators_functions.hpp b/src/Operators/Vector/vector_dist_operators_functions.hpp
index 3b05f254..9958aedb 100644
--- a/src/Operators/Vector/vector_dist_operators_functions.hpp
+++ b/src/Operators/Vector/vector_dist_operators_functions.hpp
@@ -30,10 +30,17 @@ class vector_dist_expression_op<exp1,void,OP_ID>\
 public:\
 \
 	typedef typename exp1::is_ker is_ker;\
+\
+	typedef typename vector_result<typename exp1::vtype,void>::type vtype;\
 \
 	vector_dist_expression_op(const exp1 & o1)\
 	:o1(o1)\
 	{}\
+\
+	const vtype & getVector()\
+	{\
+		return vector_result<typename exp1::vtype,void>::getVector(o1);\
+	}\
 \
 	inline const exp1 & getExpr() const\
 	{\
@@ -134,6 +141,14 @@ class vector_dist_expression_op<exp1,exp2,OP_ID>\
 public:\
 \
 	typedef std::integral_constant<bool,exp1::is_ker::value || exp1::is_ker::value> is_ker;\
+\
+	typedef typename vector_result<typename exp1::vtype,typename exp2::vtype>::type vtype;\
+\
+	vtype & getVector()\
+	{\
+		return vector_result<typename exp1::vtype,typename exp2::vtype>::getVector(o1,o2);\
+	}\
+\
 \
 	vector_dist_expression_op(const exp1 & o1, const exp2 & o2)\
 	:o1(o1),o2(o2)\
@@ -322,6 +337,9 @@ public:
 	//! Indicate if it is an in kernel expression
 	typedef typename exp1::is_ker is_ker;
 
+	//! return the vector type on which this expression operate
+	typedef typename vector_result<typename exp1::vtype,void>::type vtype;
+
 	//! constructor from an epxression exp1 and a vector vd
 	vector_dist_expression_op(const exp1 & o1)
 	:o1(o1),val(0)
-- 
GitLab