C++ code generation: any way of deflating large expressions?

Chris Dams chrisd at sci.kun.nl
Thu Jul 10 16:41:06 CEST 2003


Hello,

I also ran into the problem of generating source code, so I wrote a C++
class to do it; it is attached. The following code fragment should make it
reasonably clear how to use it:

   exoutclass p;
   p.getfunction(symbol("matelsq"),matelsq);
   p.setfunction(eta,eta);
   p.setfunction(z,z);
   p.setfunction(xi,xi);
   p.setfunction(t,t);
   p.setfunction(s,s);
   ofstream hfile("regmsq.h");
   hfile << "#include \"msq.h\"\n";
   p.writeouth("regmsq","double",print_csrc_double(hfile),"msq");
   hfile << "extern double mmu;" << endl;
   hfile << "extern double MW;" << endl;
   hfile.close();
   ofstream Cfile("regmsq.C");
   p.writeoutC("regmsq","double",print_csrc_double(Cfile));
   Cfile << "double mmu=0;" << endl;
   Cfile << "double MW=" << MWval << ";" << endl;
   Cfile.close();

The idea is that this creates a C++ class with methods sets(), sett(),
setxi(), setz() and seteta(), and a method getmatelsq() that returns the
result. The set methods should be called in the opposite order to the one
in which they were created with setfunction(). If you want to change, say,
the value of xi, the methods setxi(), setz() and seteta() should be called
in that order before calling getmatelsq(). In setfunction(), the first
argument should be a symbol whose name appears in the C++ output as a
variable name; the second argument is the expression that receives a value
when the corresponding set method (e.g. setxi()) is called.

Bye,
Chris Dams


-------------- next part --------------
An embedded and charset-unspecified text was scrubbed...
Name: exoutclass.h
Url: http://www.cebix.net/pipermail/ginac-list/attachments/20030710/0d5d027b/exoutclass.h
-------------- next part --------------
#include "exoutclass.h"
#include <iostream>
using namespace std;

// Ordering predicate on ex for std::set / std::map.
// Primarily ordered by ex::compare(). When compare() reports a tie but exactly
// one of the two is an integer (presumably an exact integer vs. an
// equal-valued floating-point number), the integer sorts first, so the two
// are kept as distinct keys.
// Returns true iff x sorts strictly before y.
bool exless::operator()(const ex&x,const ex&y)
{
	const int order=x.compare(y);
	if(order==0)
	{	const bool xint=x.info(info_flags::integer);
		const bool yint=y.info(info_flags::integer);
		if(xint!=yint)
			return xint;
	}
	return order==-1;
}

// Global pool of the arguments rest() has been applied to. rest_eval() looks
// arguments up in restargs.heap (ordered by exless) so that equal arguments
// share one stored ex. exoutclass::setfunction() clears it after each pass.
Restargs restargs;

// Evaluation function for the symbolic function rest(x).
// The argument is interned in restargs.heap: the first time an argument is
// seen it is stored; afterwards the stored (exless-equal) copy is reused.
// The result is held, so rest(...) is never evaluated any further.
ex rest_eval(const ex&x)
{
	// insert() leaves the set unchanged if an equal element already exists and
	// in either case reports where the canonical element lives.
	set<ex,exless>::iterator canonical=restargs.heap.insert(x).first;
	return rest(*canonical).hold();
}
REGISTER_FUNCTION(rest,eval_func(rest_eval))

// Splits a term x into first * rest(second).
// For a product, every rest(...) factor is pulled out by operator() (applied
// through map) and replaced by 1. Any other term is kept whole, with second = 1.
splitrestrest::splitrestrest(const ex&x)
{	// second must be 1 before map() runs, since operator() updates it.
	second=1;
	first=is_a<mul>(x) ? x.map(*this) : x;
}

// map_function callback used by the splitrestrest constructor on a product.
// A factor of the form rest(a) is removed (replaced by 1) and its argument a
// is multiplied into `second`. Every other factor is returned unchanged.
// Relies on `second` having been initialised (the constructor sets it to 1).
ex splitrestrest::operator()(const ex&x)
{	if(!x.match(rest(wild(0))))
		return x;
	// Multiply rather than overwrite. A product with several rest(...) factors
	// then keeps all of them instead of silently dropping all but the last.
	second*=x.op(0);
	return 1;
}

// Restifies x with respect to s_.
// The symbol is stored in the member s, which the one-argument overload reads.
ex Restify::operator()(const ex&x,const symbol&s_)
{
	s=s_;
	return (*this)(x);
}

// Rewrites x with respect to the symbol s last given to operator()(x,s_).
// Every sub-expression that does not contain s is wrapped in rest(...).
// In a product, all rest(...) factors are merged into a single rest(...).
// In a sum:
//   - the s-free terms become a single rest(...) term;
//   - the other terms are grouped by their non-rest part, and the rest(...)
//     coefficients of each group are added together.
ex Restify::operator()(const ex&x)
{	if(!x.has(s))
		return rest(x);
	if(x==s)
		return x;
	// Powers and functions keep their structure; only their operands are rewritten.
	if(is_a<power>(x)||is_a<function>(x))
		return x.map(*this);
	if(is_a<mul>(x))
	{	const ex tmp=x.map(*this);
		ex result=1;
		ex restarg=1;
		for(size_t i=0;i<tmp.nops();++i)
		{	const ex factor=tmp.op(i);
			if(factor.match(rest(wild())))
				restarg*=factor.op(0);
			else
				result*=factor;
		}
		if(restarg==1)
			return result;
		return result*rest(restarg);
	}
	if(is_a<add>(x))
	{	const ex tmp=x.map(*this);
		ex result=0;
		typedef map<ex,ex,exless> termmap_t;
		termmap_t termmap;
		for(size_t i=0;i<tmp.nops();++i)
		{	const ex term=tmp.op(i);
			if(term.match(rest(wild())))
			{	result+=term.op(0);
				continue;
			}
			splitrestrest srr(term);
			// Single lookup: insert a new group, or add onto the existing one.
			pair<termmap_t::iterator,bool> ins=termmap.insert(srr);
			if(!ins.second)
				ins.first->second+=srr.second;
		}
		if(result!=0)
			result=rest(result);
		for(termmap_t::iterator it=termmap.begin();it!=termmap.end();++it)
			result+=(it->first)*rest(it->second);
		return result;
	}
	return x;
}

// Shared Restify instance used by exoutclass::setfunction(). It stores the
// current separation symbol in a member between calls, so interleaved use
// with different symbols would interfere.
Restify restify;

void exoutclass::getfunction(const symbol&s,const ex&x)
{	slist.push_back(substitution(s,x,substitution::getfun));
}

// Replaces every rest(a) inside x by a temporary symbol.
// - A numeric a is substituted directly.
// - Otherwise the same a (per exless) always maps to the same symbol. That
//   symbol is created on first use by eoc->newvar() and recorded in m, from
//   which setfunction() emits the temp entries.
// Other compound expressions are processed recursively; atoms are returned as-is.
ex exoutclass::Makerestexpr::operator()(const ex&x)
{
	if(!x.match(rest(wild(0))))
		return x.nops() ? x.map(*this) : x;
	const ex arg=x.op(0);
	if(is_a<numeric>(arg))
		return arg;
	map<ex,ex,exless>::iterator known=m.find(arg);
	if(known!=m.end())
		return known->second;
	symbol fresh=eoc->newvar();
	m[arg]=fresh;
	return fresh;
}

void exoutclass::setfunction(const symbol&s,const ex&x)
{	Makerestexpr makerestexpr(this);
	for(slistit i=slist.end();i!=slist.begin();)
	{	i--;
		if((i->k)==substitution::setfun)
			break;
		i->subsex=(i->subsex).subs(x==s);
		i->subsex=restify(i->subsex,s);
		i->subsex=makerestexpr(i->subsex);
		//if(is_a<symbol>(i->subsex))
		//{	for(slistit j=slist.begin();j!=i;j++)
		//		j->subsex.subs(i->subssym==i->subsex);
		//	i=slist.erase(i);
		//}
	}
	restargs.clear();
	slist.push_back(substitution(s,0,substitution::setfun));
	for(map<ex,ex,exless>::iterator i=makerestexpr.m.begin();
						i!=makerestexpr.m.end();i++)
		slist.push_back(substitution(i->second,i->first,
							substitution::temp));
}

// Writes an abstract base class `name` to c.s. It declares a pure virtual
// set<var>(ftype) per setfun entry and get<var>() per getfun entry, newest
// entry first.
// NOTE(review): name/ftype are `char*` per exoutclass.h. Passing string
// literals is ill-formed since C++11; consider `const char*` in the header.
void exoutclass::writeoutbase(char*name,char*ftype,const print_context&c)
{	// Guard derived from the class name. A fixed guard would make a second
	// generated base header silently empty in the same translation unit.
	c.s << "#ifndef " << name << "_VAR" << endl;
	c.s << "#define " << name << "_VAR" << endl;
	c.s << "class " << name << endl;
	c.s << "{private:typedef " << ftype << " floattype;" << endl;
	c.s << "public:" << name << "(){}" << endl;
	// The class has only pure virtual members and is meant as a public base
	// (see writeouth's parent argument), so give it a virtual destructor.
	c.s << "public:virtual ~" << name << "(){}" << endl;
	// Newest entries first. This form is also well-defined for an empty slist,
	// unlike starting from --slist.end().
	for(slistit i=slist.end();i!=slist.begin();)
	{	--i;
		switch(i->k)
		{	case substitution::setfun:
				c.s << "public:virtual void set"
				    << ex_to<symbol>(i->subssym).get_name()
				    << "(" << ftype << ")=0;" << endl;
				break;
			case substitution::getfun:
				c.s << "public: virtual " << ftype << " get"
				    << ex_to<symbol>(i->subssym).get_name()
				    << "()=0;" << endl;
				break;
			default:
				// Temporaries are implementation details of the derived class.
				break;
		}
	}
	c.s << "};" << endl;
	c.s << "#endif" << endl;
}

void exoutclass::writeouth(char*name,char*ftype,const print_context&c,char*parent)
{	c.s << "class " << name;
	if(*parent)
		c.s << ":public " << parent;
	c.s << endl;
	c.s << "{private:typedef " << ftype << " floattype;" << endl;
	c.s << "private:int functionsdone;" << endl;
	c.s << "public:" << name << "();" << endl;
	slistit i=--slist.end();
	do
	{	switch(i->k)
		{	case substitution::setfun:
				c.s << "public:void set"
				    << ex_to<symbol>(i->subssym).get_name()
				    << "(" << ftype << ");" << endl;;
				c.s << "private:floattype "
				    << ex_to<symbol>(i->subssym).get_name()
				    << " ;" << endl;
				break;
			case substitution::getfun:
				c.s << "private:" << ftype << " "
				    << ex_to<symbol>(i->subssym).get_name()
				    << ";" << endl;;
				c.s << "public:" << ftype << " get"
				    << ex_to<symbol>(i->subssym).get_name()
				    << "();" << endl;;
				break;
			case substitution::temp:
				c.s << "private:floattype "
				    << ex_to<symbol>(i->subssym).get_name()
				    << ";" << endl;
		}
	} while(i==slist.begin() ? false : (i--,true));
	c.s << "};" << endl;
}

// Writes the implementation of the class declared by writeouth() to c.s.
// slist is consumed from the back in sections separated by setfun entries:
//   - the constructor evaluates the entries after the last setfun;
//   - each generated set<var>() evaluates the entries between its setfun and
//     the previous one.
// Getters are defined for the getfun entries of each section. donecount and
// the generated functionsdone make the output warn (on cerr) when a set/get
// function is called before the set functions that must precede it.
void exoutclass::writeoutC(char*name,char*ftype,const print_context&c)
{	int donecount=0;
	c.s << "#include <iostream>" << endl;
	c.s << "#include \"" << name << ".h\"" << endl;
	c.s << "using namespace std;" << endl;
	c.s << name << "::" << name << "()" << endl;
	c.s << "{" << endl;
	c.s << "functionsdone=0;" << endl;
	slistit startsection=slist.end();
	while(true)
	{	// Assignments of this section go into the body that is currently open:
		// the constructor on the first pass, otherwise the previous setter.
		slistit pos=startsection;
		while(pos!=slist.begin()&&(--pos)->k!=substitution::setfun)
		{	c.s << ex_to<symbol>(pos->subssym).get_name() << "=";
			(pos->subsex).print(c);
			c.s << ";" << endl;
		}
		c.s << "}" << endl;
		// Getters for the results computed in this section.
		pos=startsection;
		while(pos!=slist.begin()&&(--pos)->k!=substitution::setfun)
			if(pos->k==substitution::getfun)
			{	c.s << name << "::floattype "<< name << "::get"
				    << pos->subssym << "()" << endl;
				c.s << "{" << endl;
				c.s << "if(functionsdone<" << donecount << ")"
				    << endl;
				c.s << "cerr << \"Warning: function get"
				    << pos->subssym
				    << " called prematurely.\" << endl;"
				    << endl;
				c.s << "return " << pos->subssym << ";" << endl;
				c.s << "}" << endl;
			}
		// Continue only if the scan stopped on a setfun entry. Testing the kind
		// (not just pos==begin) also emits the setter when the very first slist
		// entry is a setfun; previously that setter was declared but never
		// defined. pos==startsection is tested first so an empty list never
		// dereferences end().
		if(pos==startsection||pos->k!=substitution::setfun)
			break;
		c.s << "void " << name << "::set" << pos->subssym << "("
		    << ftype << " " << pos->subssym << ")" << endl;
		c.s << "{" << endl;
		c.s << "if(functionsdone<" << donecount << ")" << endl;
		c.s << "cerr << \"Warning: function set" << pos->subssym
		    << " called prematurely.\" << endl;" << endl;
		donecount++;
		c.s << "functionsdone=" << donecount << ";" << endl;
		// The setter body stays open. The next pass fills it and closes it.
		startsection=pos;
	}
}

void exoutclass::print()
{	for(slistit i=slist.begin();i!=slist.end();i++)
		cout << i->subssym << " = " << i->subsex << endl;
}


More information about the GiNaC-list mailing list