panann
RecurrentNeuralNetwork.h
1 //-------------------------------------------------------------------------------------------------------
2 // Copyright (C) Taylor Woll and panann contributors. All rights reserved.
3 // Licensed under the MIT license. See LICENSE.txt file in the project root for
4 // full license information.
5 //-------------------------------------------------------------------------------------------------------
6 
7 #ifndef RECURRENTNEURALNETWORK_H__
8 #define RECURRENTNEURALNETWORK_H__
9 
10 #include "Perceptron.h"
11 
12 namespace panann {
13 
26  protected:
33 
39  size_t neuron_count;
40 
46 
52 
53  size_t GetNeuronsPerGate() const;
54  size_t GetForgetGateStartNeuronIndex() const;
55  size_t GetInputGateStartNeuronIndex() const;
56  size_t GetOutputGateStartNeuronIndex() const;
57  size_t GetCandidateCellStateStartNeuronIndex() const;
58  size_t GetOutputUnitStartNeuronIndex() const;
59  };
60 
61  struct CellLayer {
66 
70  size_t cell_count;
71  };
72 
73  public:
74  RecurrentNeuralNetwork() = default;
76  RecurrentNeuralNetwork& operator=(const RecurrentNeuralNetwork&) = delete;
77  ~RecurrentNeuralNetwork() override = default;
78 
87  void SetCellMemorySize(size_t memory_size);
88  size_t GetCellMemorySize() const;
89 
90  void RunForward(const std::vector<double>& input) override;
91 
95  std::vector<double>& GetCellStates();
96 
97  void AddHiddenLayer(size_t neuron_count) = delete;
98 
107  void AddHiddenLayer(size_t cell_count,
108  const std::vector<size_t>& cell_memory_sizes = {});
109 
110  void Construct() override;
111 
112  protected:
113  void ConnectFully() override;
114  void FixNeuronConnectionIndices() override;
115  void InitializeHiddenNeurons() override;
116 
117  void AllocateCellStates();
118  bool AreCellStatesAllocated() const;
119 
120  void UpdateCellState(const LongShortTermMemoryCell& cell);
121 
122  size_t AddCellMemoryStates(size_t count);
123 
124  size_t GetCellCount() const;
125  LongShortTermMemoryCell& GetCell(size_t index);
126  size_t GetCellLayerCount() const;
127  CellLayer& GetCellLayer(size_t index);
128 
136  void InitializeCellNeurons(const LongShortTermMemoryCell& cell,
137  size_t input_connection_count,
138  size_t output_connection_count);
139 
143  void InitializeCellNeuronsOneGate(size_t neuron_start_index,
144  size_t neurons_per_gate,
145  ActivationFunctionType activation_function,
146  size_t input_connection_count,
147  size_t output_connection_count);
148 
149  private:
150  static constexpr size_t DefaultCellMemorySize = 200;
151 
152  std::vector<CellLayer> layers_;
153  std::vector<LongShortTermMemoryCell> cells_;
154  std::vector<double> cell_states_;
155  size_t cell_states_count_ = 0;
156  size_t cell_memory_size_ = DefaultCellMemorySize;
157  bool is_allocated_ = false;
158 };
159 
160 } // namespace panann
161 
162 #endif // RECURRENTNEURALNETWORK_H__
panann::RecurrentNeuralNetwork::LongShortTermMemoryCell::neuron_count
size_t neuron_count
Definition: RecurrentNeuralNetwork.h:39
panann::RecurrentNeuralNetwork::GetCellStates
std::vector< double > & GetCellStates()
Definition: RecurrentNeuralNetwork.cc:453
panann::RecurrentNeuralNetwork::InitializeCellNeuronsOneGate
void InitializeCellNeuronsOneGate(size_t neuron_start_index, size_t neurons_per_gate, ActivationFunctionType activation_function, size_t input_connection_count, size_t output_connection_count)
Definition: RecurrentNeuralNetwork.cc:121
panann::RecurrentNeuralNetwork::LongShortTermMemoryCell::cell_state_start_index
size_t cell_state_start_index
Definition: RecurrentNeuralNetwork.h:45
panann::RecurrentNeuralNetwork::CellLayer::cell_start_index
size_t cell_start_index
Definition: RecurrentNeuralNetwork.h:65
panann::RecurrentNeuralNetwork::LongShortTermMemoryCell::neuron_start_index
size_t neuron_start_index
Definition: RecurrentNeuralNetwork.h:32
panann::RecurrentNeuralNetwork::Construct
void Construct() override
Definition: RecurrentNeuralNetwork.cc:310
panann::Perceptron
Definition: Perceptron.h:23
panann::RecurrentNeuralNetwork
Definition: RecurrentNeuralNetwork.h:25
panann::RecurrentNeuralNetwork::CellLayer
Definition: RecurrentNeuralNetwork.h:61
panann::RecurrentNeuralNetwork::CellLayer::cell_count
size_t cell_count
Definition: RecurrentNeuralNetwork.h:70
panann::RecurrentNeuralNetwork::InitializeHiddenNeurons
void InitializeHiddenNeurons() override
Definition: RecurrentNeuralNetwork.cc:174
panann::RecurrentNeuralNetwork::FixNeuronConnectionIndices
void FixNeuronConnectionIndices() override
Definition: RecurrentNeuralNetwork.cc:180
panann::RecurrentNeuralNetwork::InitializeCellNeurons
void InitializeCellNeurons(const LongShortTermMemoryCell &cell, size_t input_connection_count, size_t output_connection_count)
Definition: RecurrentNeuralNetwork.cc:142
panann::RecurrentNeuralNetwork::SetCellMemorySize
void SetCellMemorySize(size_t memory_size)
Definition: RecurrentNeuralNetwork.cc:50
panann::RecurrentNeuralNetwork::LongShortTermMemoryCell
Definition: RecurrentNeuralNetwork.h:27
panann::RecurrentNeuralNetwork::RunForward
void RunForward(const std::vector< double > &input) override
Definition: RecurrentNeuralNetwork.cc:433
panann::RecurrentNeuralNetwork::LongShortTermMemoryCell::cell_state_count
size_t cell_state_count
Definition: RecurrentNeuralNetwork.h:51