1 #ifndef CAFFE_SOLVER_HPP_ 2 #define CAFFE_SOLVER_HPP_ 3 #include <boost/function.hpp> 7 #include "caffe/net.hpp" 8 #include "caffe/solver_factory.hpp" 9 #include "caffe/util/benchmark.hpp" 21 namespace SolverAction {
// NOTE(review): this region appears to be a Doxygen source-browser extraction of
// caffe/solver.hpp, not compilable source. The leading integers on many lines
// look like listing line numbers, and several original lines are absent (e.g.
// the `class Solver` header, access specifiers, the body of test_nets(), the
// signature enclosing `callbacks_.push_back(value)`, and the closing brace).
// Restore from the upstream header rather than editing this text.
//
// Solver: "An interface for classes that perform optimization on Nets."
// Templated on the scalar type Dtype.
41 template <
typename Dtype>
// Construct a solver from an in-memory SolverParameter.
44 explicit Solver(
const SolverParameter& param);
// Construct a solver from param_file; presumably a path to a serialized
// SolverParameter -- TODO confirm against the implementation.
45 explicit Solver(
const string& param_file);
// Initialization entry point taking a SolverParameter (implementation not shown).
46 void Init(
const SolverParameter& param);
// Register an ActionCallback (a function returning SolverAction::Enum);
// presumably stored in action_request_function_ -- confirm in the .cpp.
53 void SetActionFunction(ActionCallback func);
// Returns the currently requested SolverAction::Enum (implementation not shown).
54 SolverAction::Enum GetRequestedAction();
// Main optimization entry point; resume_file defaults to NULL.
57 virtual void Solve(
const char* resume_file = NULL);
// Convenience overload: forwards resume_file.c_str() to Solve(const char*).
// Note: this always passes a non-NULL pointer, so an empty string arrives as ""
// rather than NULL. The string parameter is taken by value (a copy).
58 inline void Solve(
const string resume_file) { Solve(resume_file.c_str()); }
// Restore from resume_file (implementation not shown).
63 void Restore(
const char* resume_file);
// Read-only access to the stored SolverParameter (param_).
70 inline const SolverParameter& param()
const {
return param_; }
// Returns a shared_ptr copy of the training net (net_).
71 inline shared_ptr<Net<Dtype> > net() {
return net_; }
// Accessor for the test nets; its body is missing from this listing.
72 inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
// Returns the iteration counter iter_ (its declaration is not shown here).
75 int iter()
const {
return iter_; }
// The two pure virtual hooks below appear to belong to a nested callback
// interface (cf. the Callback type stored in callbacks_) whose header line is
// missing from this listing -- confirm against upstream.
80 virtual void on_start() = 0;
81 virtual void on_gradients_ready() = 0;
// Read-only access to the registered callbacks (callbacks_).
86 const vector<Callback*>& callbacks()
const {
return callbacks_; }
// Body of a member that appends `value` to callbacks_; its signature line is
// missing from this listing. Callbacks are kept as raw pointers, so ownership
// presumably stays with the caller -- TODO confirm.
88 callbacks_.push_back(value);
// Checks snapshot write permissions (implementation not shown).
91 void CheckSnapshotWritePermissions();
// Returns the solver type; the base-class implementation returns "".
95 virtual inline const char*
type()
const {
return ""; }
// Pure virtual subclass hook; presumably applies the per-iteration parameter
// update -- confirm in a concrete solver.
99 virtual void ApplyUpdate() = 0;
// Builds a snapshot filename for the given extension (taken by value).
100 string SnapshotFilename(
const string extension);
// Snapshot writers; each returns a string (presumably the written filename --
// confirm in the .cpp).
101 string SnapshotToBinaryProto();
102 string SnapshotToHDF5();
// Runs the test net selected by test_net_id (default 0); presumably an index
// into test_nets_.
105 void Test(
const int test_net_id = 0);
// Pure virtual subclass hooks for saving/restoring solver-specific state.
106 virtual void SnapshotSolverState(
const string& model_filename) = 0;
107 virtual void RestoreSolverStateFromHDF5(
const string& state_file) = 0;
108 virtual void RestoreSolverStateFromBinaryProto(
const string& state_file) = 0;
// Displays the output blobs of the net selected by net_id (implementation not shown).
109 void DisplayOutputBlobs(
const int net_id);
// Updates the smoothed loss from `loss`; presumably a running average over the
// last average_loss entries kept in losses_ -- confirm in the .cpp.
110 void UpdateSmoothedLoss(Dtype loss,
int start_iter,
int average_loss);
// Parameters this solver was built from (returned by param()).
112 SolverParameter param_;
// Training net (returned by net()) and the test nets.
115 shared_ptr<Net<Dtype> > net_;
116 vector<shared_ptr<Net<Dtype> > > test_nets_;
// Registered callbacks (returned by callbacks()).
117 vector<Callback*> callbacks_;
// Loss history and its smoothed value (presumably maintained by UpdateSmoothedLoss).
118 vector<Dtype> losses_;
119 Dtype smoothed_loss_;
// Callback queried for a SolverAction (presumably set via SetActionFunction).
123 ActionCallback action_request_function_;
// Flag recording an early-exit request.
126 bool requested_early_exit_;
// Iteration timing state; Timer presumably comes from the included
// caffe/util/benchmark.hpp.
129 Timer iteration_timer_;
130 float iterations_last_;
// Macro defined elsewhere; presumably disables copy construction and copy
// assignment for Solver -- confirm its definition.
132 DISABLE_COPY_AND_ASSIGN(
Solver);
137 #endif // CAFFE_SOLVER_HPP_ Definition: benchmark.hpp:10
A layer factory that allows one to register layers. During runtime, registered layers can be called by passing a LayerParameter protocol buffer to the CreateLayer function.
Definition: blob.hpp:14
Definition: solver.hpp:78
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:42
virtual const char * type() const
Returns the solver type name; the base-class implementation returns an empty string.
Definition: solver.hpp:95
boost::function< SolverAction::Enum()> ActionCallback
Type of a callback that takes no arguments and returns a SolverAction::Enum value.
Definition: solver.hpp:33