panann
FeedForwardNeuralNetwork.h
1 //-------------------------------------------------------------------------------------------------------
2 // Copyright (C) Taylor Woll and panann contributors. All rights reserved.
3 // Licensed under the MIT license. See LICENSE.txt file in the project root for
4 // full license information.
5 //-------------------------------------------------------------------------------------------------------
6 
7 #ifndef FEEDFORWARDNEURALNETWORK_H__
8 #define FEEDFORWARDNEURALNETWORK_H__
9 
10 #include <vector>
11 
12 #include "Perceptron.h"
13 
14 namespace panann {
15 
16 class TrainingData;
17 
22  public:
23  enum class TrainingAlgorithmType : uint8_t {
31  Backpropagation = 0,
76  };
77 
78  FeedForwardNeuralNetwork() = default;
80  FeedForwardNeuralNetwork& operator=(const FeedForwardNeuralNetwork&) = delete;
81  ~FeedForwardNeuralNetwork() override = default;
82 
87  void SetLearningRate(double learning_rate);
88  double GetLearningRate() const;
89 
94  void SetMomentum(double momentum);
95  double GetMomentum() const;
96 
101  void SetQpropMu(double mu);
102  double GetQpropMu() const;
103 
108  void SetQpropWeightDecay(double weight_decay);
109  double GetQpropWeightDecay() const;
110 
116  void SetRpropWeightStepInitial(double weight_step);
117  double GetRpropWeightStepInitial() const;
118 
126  void SetRpropWeightStepMin(double weight_step);
127  double GetRpropWeightStepMin() const;
128 
137  void SetRpropWeightStepMax(double weight_step);
138  double GetRpropWeightStepMax() const;
139 
150  void SetRpropIncreaseFactor(double factor);
151  double GetRpropIncreaseFactor() const;
152 
162  void SetRpropDecreaseFactor(double factor);
163  double GetRpropDecreaseFactor() const;
164 
171  void SetSarpropWeightDecayShift(double k1);
172  double GetSarpropWeightDecayShift() const;
173 
179  void SetSarpropStepThresholdFactor(double k2);
180  double GetSarpropStepThresholdFactor() const;
181 
187  void SetSarpropStepShift(double k3);
188  double GetSarpropStepShift() const;
189 
195  void SetSarpropTemperature(double t);
196  double GetSarpropTemperature() const;
197 
204  TrainingAlgorithmType GetTrainingAlgorithmType() const;
205 
228  void Train(TrainingData* training_data, size_t epoch_count);
229 
230  protected:
231  void UpdateSlopes();
232  void UpdateWeightsOnline();
233  void UpdateWeightsOffline(size_t current_epoch, size_t step_count);
234  void UpdateWeightsBatchingBackpropagation(size_t step_count);
235  void UpdateWeightsQuickBackpropagation(size_t step_count);
236  void UpdateWeightsResilientBackpropagation();
237  void UpdateWeightsSimulatedAnnealingResilientBackpropagation(
238  size_t current_epoch);
239 
240  void ResetWeightSteps();
241  void ResetSlopes();
242  void ResetPreviousSlopes();
243 
244  void TrainOffline(TrainingData* training_data, size_t epoch_count);
245  void TrainOnline(TrainingData* training_data, size_t epoch_count);
246 
247  private:
248  static constexpr double DefaultLearningRate = 0.7;
249  static constexpr double DefaultMomentum = 0.1;
250  static constexpr double DefaultQpropMu = 1.75;
251  static constexpr double DefaultQpropWeightDecay = -0.0001;
252  static constexpr double DefaultRpropWeightStepInitial = 0.0125;
253  static constexpr double DefaultRpropWeightStepMin = 0.000001;
254  static constexpr double DefaultRpropWeightStepMax = 50;
255  static constexpr double DefaultRpropIncreaseFactor = 1.2;
256  static constexpr double DefaultRpropDecreaseFactor = 0.5;
257  static constexpr double DefaultSarpropWeightDecayShift = 0.01;
258  static constexpr double DefaultSarpropStepThresholdFactor = 0.1;
259  static constexpr double DefaultSarpropStepShift = 3;
260  static constexpr double DefaultSarpropTemperature = 0.015;
261 
262  std::vector<double> previous_weight_steps_;
263  std::vector<double> slopes_;
264  std::vector<double> previous_slopes_;
265 
266  double learning_rate_ = DefaultLearningRate;
267  double momentum_ = DefaultMomentum;
268  double qprop_mu_ = DefaultQpropMu;
269  double qprop_weight_decay_ = DefaultQpropWeightDecay;
270  double rprop_weight_step_initial_ = DefaultRpropWeightStepInitial;
271  double rprop_weight_step_min_ = DefaultRpropWeightStepMin;
272  double rprop_weight_step_max_ = DefaultRpropWeightStepMax;
273  double rprop_increase_factor_ = DefaultRpropIncreaseFactor;
274  double rprop_decrease_factor_ = DefaultRpropDecreaseFactor;
275  double sarprop_weight_decay_shift_ = DefaultSarpropWeightDecayShift;
276  double sarprop_step_threshold_factor_ = DefaultSarpropStepThresholdFactor;
277  double sarprop_step_shift_ = DefaultSarpropStepShift;
278  double sarprop_temperature_ = DefaultSarpropTemperature;
279 
280  TrainingAlgorithmType training_algorithm_type_ =
282 };
283 
284 } // namespace panann
285 
286 #endif // FEEDFORWARDNEURALNETWORK_H__
panann::FeedForwardNeuralNetwork::SetTrainingAlgorithmType
void SetTrainingAlgorithmType(TrainingAlgorithmType type)
Definition: FeedForwardNeuralNetwork.cc:114
panann::TrainingData
Definition: TrainingData.h:30
panann::FeedForwardNeuralNetwork::SetRpropIncreaseFactor
void SetRpropIncreaseFactor(double factor)
Definition: FeedForwardNeuralNetwork.cc:66
panann::FeedForwardNeuralNetwork::SetRpropWeightStepMin
void SetRpropWeightStepMin(double weight_step)
Definition: FeedForwardNeuralNetwork.cc:50
panann::FeedForwardNeuralNetwork::TrainingAlgorithmType::QuickBackpropagation
@ QuickBackpropagation
panann::FeedForwardNeuralNetwork::SetSarpropStepThresholdFactor
void SetSarpropStepThresholdFactor(double k2)
Definition: FeedForwardNeuralNetwork.cc:90
panann::FeedForwardNeuralNetwork::SetSarpropTemperature
void SetSarpropTemperature(double t)
Definition: FeedForwardNeuralNetwork.cc:106
panann::Perceptron
Definition: Perceptron.h:23
panann::FeedForwardNeuralNetwork::SetRpropWeightStepMax
void SetRpropWeightStepMax(double weight_step)
Definition: FeedForwardNeuralNetwork.cc:58
panann::FeedForwardNeuralNetwork::SetMomentum
void SetMomentum(double momentum)
Definition: FeedForwardNeuralNetwork.cc:24
panann::FeedForwardNeuralNetwork::SetQpropMu
void SetQpropMu(double mu)
Definition: FeedForwardNeuralNetwork.cc:30
panann::FeedForwardNeuralNetwork::TrainingAlgorithmType::BatchingBackpropagation
@ BatchingBackpropagation
panann::FeedForwardNeuralNetwork::SetRpropWeightStepInitial
void SetRpropWeightStepInitial(double weight_step)
Definition: FeedForwardNeuralNetwork.cc:42
panann::FeedForwardNeuralNetwork::TrainingAlgorithmType
TrainingAlgorithmType
Definition: FeedForwardNeuralNetwork.h:23
panann::FeedForwardNeuralNetwork
Definition: FeedForwardNeuralNetwork.h:21
panann::FeedForwardNeuralNetwork::SetQpropWeightDecay
void SetQpropWeightDecay(double weight_decay)
Definition: FeedForwardNeuralNetwork.cc:34
panann::FeedForwardNeuralNetwork::SetRpropDecreaseFactor
void SetRpropDecreaseFactor(double factor)
Definition: FeedForwardNeuralNetwork.cc:74
panann::FeedForwardNeuralNetwork::TrainingAlgorithmType::SimulatedAnnealingResilientBackpropagation
@ SimulatedAnnealingResilientBackpropagation
panann::FeedForwardNeuralNetwork::SetSarpropStepShift
void SetSarpropStepShift(double k3)
Definition: FeedForwardNeuralNetwork.cc:98
panann::FeedForwardNeuralNetwork::SetSarpropWeightDecayShift
void SetSarpropWeightDecayShift(double k1)
Definition: FeedForwardNeuralNetwork.cc:82
panann::FeedForwardNeuralNetwork::Train
void Train(TrainingData *training_data, size_t epoch_count)
Definition: FeedForwardNeuralNetwork.cc:368
panann::FeedForwardNeuralNetwork::TrainingAlgorithmType::ResilientBackpropagation
@ ResilientBackpropagation
panann::FeedForwardNeuralNetwork::TrainingAlgorithmType::Backpropagation
@ Backpropagation
panann::FeedForwardNeuralNetwork::SetLearningRate
void SetLearningRate(double learning_rate)
Definition: FeedForwardNeuralNetwork.cc:16