10#ifndef ROL_FDIVERGENCE_HPP
11#define ROL_FDIVERGENCE_HPP
78 ROL_TEST_FOR_EXCEPTION((
thresh_ <=
zero), std::invalid_argument,
79 ">>> ERROR (ROL::FDivergence): Threshold must be positive!");
102 ROL::ParameterList &list
103 = parlist.sublist(
"SOL").sublist(
"Risk Measure").sublist(
"F-Divergence");
104 thresh_ = list.get<Real>(
"Threshold");
115 virtual Real
Fprimal(Real x,
int deriv = 0)
const = 0;
129 virtual Real
Fdual(Real x,
int deriv = 0)
const = 0;
131 bool check(std::ostream &outStream = std::cout)
const {
135 Real x =
static_cast<Real
>(rand())/
static_cast<Real
>(RAND_MAX);
136 Real t =
static_cast<Real
>(rand())/
static_cast<Real
>(RAND_MAX);
139 outStream <<
"Check Fenchel-Young Inequality: F(x) + F*(t) >= xt" << std::endl;
140 outStream <<
"x = " << x << std::endl;
141 outStream <<
"t = " << t << std::endl;
142 outStream <<
"F(x) = " << fp << std::endl;
143 outStream <<
"F*(t) = " << fd << std::endl;
144 outStream <<
"Is Valid? " << (fp+fd >= x*t) << std::endl;
145 flag = (fp+fd >= x*t) ? flag :
false;
147 x =
static_cast<Real
>(rand())/
static_cast<Real
>(RAND_MAX);
151 outStream <<
"Check Fenchel-Young Equality: F(x) + F(t) = xt for t = d/dx F(x)" << std::endl;
152 outStream <<
"x = " << x << std::endl;
153 outStream <<
"t = " << t << std::endl;
154 outStream <<
"F(x) = " << fp << std::endl;
155 outStream <<
"F*(t) = " << fd << std::endl;
156 outStream <<
"Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
157 flag = (std::abs(fp+fd - x*t)<=tol) ? flag :
false;
159 t =
static_cast<Real
>(rand())/
static_cast<Real
>(RAND_MAX);
163 outStream <<
"Check Fenchel-Young Equality: F(x) + F(t) = xt for x = d/dt F*(t)" << std::endl;
164 outStream <<
"x = " << x << std::endl;
165 outStream <<
"t = " << t << std::endl;
166 outStream <<
"F(x) = " << fp << std::endl;
167 outStream <<
"F*(t) = " << fd << std::endl;
168 outStream <<
"Is Valid? " << (std::abs(fp+fd - x*t)<=tol) << std::endl;
169 flag = (std::abs(fp+fd - x*t)<=tol) ? flag :
false;
182 const std::vector<Real> &xstat,
184 Real val = computeValue(obj,x,tol);
185 Real xlam = xstat[0];
187 Real r =
Fdual((val-xmu)/xlam,0);
192 const std::vector<Real> &xstat,
195 sampler.
sumAll(&val_,&val,1);
196 Real xlam = xstat[0];
198 return xlam*(
thresh_ + val) + xmu;
204 const std::vector<Real> &xstat,
206 Real val = computeValue(obj,x,tol);
207 Real xlam = xstat[0];
209 Real inp = (val-xmu)/xlam;
213 val_ += weight_ * r0;
218 computeGradient(*dualVector_,obj,x,tol);
219 g_->axpy(weight_*r1,*dualVector_);
224 std::vector<Real> &gstat,
226 const std::vector<Real> &xstat,
228 std::vector<Real> mygval(3), gval(3);
232 sampler.
sumAll(&mygval[0],&gval[0],3);
234 gstat[0] =
thresh_ + gval[0] + gval[1];
235 gstat[1] =
static_cast<Real
>(1) + gval[2];
242 const std::vector<Real> &vstat,
244 const std::vector<Real> &xstat,
246 Real val = computeValue(obj,x,tol);
247 Real xlam = xstat[0];
249 Real vlam = vstat[0];
251 Real inp = (val-xmu)/xlam;
254 Real gv = computeGradVec(*dualVector_,obj,v,x,tol);
255 val_ += weight_ * r2 * inp;
256 valLam_ += weight_ * r2 * inp * inp;
257 valLam2_ -= weight_ * r2 * gv * inp;
260 hv_->axpy(weight_ * r2 * (gv - vmu - vlam*inp)/xlam, *dualVector_);
263 computeHessVec(*dualVector_,obj,v,x,tol);
264 hv_->axpy(weight_ * r1, *dualVector_);
269 std::vector<Real> &hvstat,
271 const std::vector<Real> &vstat,
273 const std::vector<Real> &xstat,
275 std::vector<Real> myhval(5), hval(5);
281 sampler.
sumAll(&myhval[0],&hval[0],5);
283 std::vector<Real> stat(2);
284 Real xlam = xstat[0];
286 Real vlam = vstat[0];
288 hvstat[0] = (vlam * hval[1] + vmu * hval[0] + hval[2])/xlam;
289 hvstat[1] = (vlam * hval[0] + vmu * hval[3] + hval[4])/xlam;
Objective_SerialSimOpt(const Ptr< Obj > &obj, const V &ui) z0 zero)()
Contains definitions of custom data types in ROL.
void updateGradient(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
void updateValue(Objective< Real > &obj, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
FDivergence(ROL::ParameterList &parlist)
Constructor.
void getHessVec(Vector< Real > &hv, std::vector< Real > &hvstat, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
void getGradient(Vector< Real > &g, std::vector< Real > &gstat, const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
virtual Real Fprimal(Real x, int deriv=0) const =0
Implementation of the scalar primal F function.
Real getValue(const Vector< Real > &x, const std::vector< Real > &xstat, SampleGenerator< Real > &sampler)
void initialize(const Vector< Real > &x)
FDivergence(const Real thresh)
Constructor.
void checkInputs(void) const
bool check(std::ostream &outStream=std::cout) const
virtual Real Fdual(Real x, int deriv=0) const =0
Implementation of the scalar dual F function.
void updateHessVec(Objective< Real > &obj, const Vector< Real > &v, const std::vector< Real > &vstat, const Vector< Real > &x, const std::vector< Real > &xstat, Real &tol)
Provides the interface to evaluate objective functions.
Provides the interface to implement any functional that maps a random variable to a (extended) real n...
void sumAll(Real *input, Real *output, int dim) const
Defines the linear algebra or vector space interface.
Real ROL_EPSILON(void)
Platform-dependent machine epsilon.