Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
- Exportable Extractors now append by default with option to overwrite
- Added validation interval parameter to MLPs and GBM Learners
- Removed output layer L2 Penalty parameter from MLP Learners
- Removed Network interface
- RBX Serializer only tracks major library version number

- Converted NeuralNet classes to use NDArray instead of Matrix
- Converted Network back into an interface

- 2.5.0
- Added Vantage Point Spatial tree
- Blob Generator can now `simulate()` a Dataset object
Expand Down
3 changes: 2 additions & 1 deletion src/Classifiers/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace Rubix\ML\Classifiers;

use Rubix\ML\NeuralNet\FeedForward;
use Rubix\ML\Online;
use Rubix\ML\Learner;
use Rubix\ML\Verbose;
Expand Down Expand Up @@ -289,7 +290,7 @@ public function train(Dataset $dataset) : void

$classes = $dataset->possibleOutcomes();

$this->network = new Network(
$this->network = new FeedForward(
new Placeholder1D($dataset->numFeatures()),
[new Dense(1, $this->l2Penalty, true, new Xavier1())],
new Binary($classes, $this->costFn),
Expand Down
3 changes: 2 additions & 1 deletion src/Classifiers/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace Rubix\ML\Classifiers;

use Rubix\ML\NeuralNet\FeedForward;
use Rubix\ML\Online;
use Rubix\ML\Learner;
use Rubix\ML\Verbose;
Expand Down Expand Up @@ -370,7 +371,7 @@ public function train(Dataset $dataset) : void

$hiddenLayers[] = new Dense(count($classes), 0.0, true, new Xavier1());

$this->network = new Network(
$this->network = new FeedForward(
new Placeholder1D($dataset->numFeatures()),
$hiddenLayers,
new Multiclass($classes, $this->costFn),
Expand Down
3 changes: 2 additions & 1 deletion src/Classifiers/SoftmaxClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace Rubix\ML\Classifiers;

use Rubix\ML\NeuralNet\FeedForward;
use Rubix\ML\Online;
use Rubix\ML\Learner;
use Rubix\ML\Verbose;
Expand Down Expand Up @@ -285,7 +286,7 @@ public function train(Dataset $dataset) : void

$classes = $dataset->possibleOutcomes();

$this->network = new Network(
$this->network = new FeedForward(
new Placeholder1D($dataset->numFeatures()),
[new Dense(count($classes), $this->l2Penalty, true, new Xavier1())],
new Multiclass($classes, $this->costFn),
Expand Down
2 changes: 1 addition & 1 deletion src/NeuralNet/FeedForward.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
* @package Rubix/ML
* @author Andrew DalPino
*/
class FeedForward extends Network
class FeedForward implements Network
{
/**
* The input layer to the network.
Expand Down
254 changes: 4 additions & 250 deletions src/NeuralNet/Network.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,270 +2,24 @@

namespace Rubix\ML\NeuralNet;

use Tensor\Matrix;
use Rubix\ML\Encoding;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\NeuralNet\Layers\Input;
use Rubix\ML\NeuralNet\Layers\Output;
use Rubix\ML\NeuralNet\Layers\Parametric;
use Rubix\ML\NeuralNet\Optimizers\Adaptive;
use Rubix\ML\NeuralNet\Optimizers\Optimizer;
use Traversable;

use function array_reverse;

/**
* Network
*
* A neural network implementation consisting of an input and output layer and any number
* of intermediate hidden layers.
*
* @internal
*
* @category Machine Learning
* @package Rubix/ML
* @author Andrew DalPino
* @author Samuel Akopyan <leumas.a@gmail.com>
*/
class Network
interface Network
{
/**
* The input layer to the network.
*
* @var Input
*/
protected Input $input;

/**
* The hidden layers of the network.
*
* @var list<Layers\Hidden>
*/
protected array $hidden = [
//
];

/**
* The pathing of the backward pass through the hidden layers.
*
* @var list<Layers\Hidden>
*/
protected array $backPass = [
//
];

/**
* The output layer of the network.
*
* @var Output
*/
protected Output $output;

/**
* The gradient descent optimizer used to train the network.
*
* @var Optimizer
*/
protected Optimizer $optimizer;

/**
* @param Input $input
* @param Layers\Hidden[] $hidden
* @param Output $output
* @param Optimizer $optimizer
*/
public function __construct(Input $input, array $hidden, Output $output, Optimizer $optimizer)
{
$hidden = array_values($hidden);

$backPass = array_reverse($hidden);

$this->input = $input;
$this->hidden = $hidden;
$this->output = $output;
$this->optimizer = $optimizer;
$this->backPass = $backPass;
}

/**
* Return the input layer.
*
* @return Input
*/
public function input() : Input
{
return $this->input;
}

/**
* Return an array of hidden layers indexed left to right.
*
* @return list<Layers\Hidden>
*/
public function hidden() : array
{
return $this->hidden;
}

/**
* Return the output layer.
*
* @return Output
*/
public function output() : Output
{
return $this->output;
}

/**
* Return all the layers in the network.
* Return the layers of the network.
*
* @return Traversable<Layers\Layer>
*/
public function layers() : Traversable
{
yield $this->input;

yield from $this->hidden;

yield $this->output;
}

/**
* Return the number of trainable parameters in the network.
*
* @return int
*/
public function numParams() : int
{
$numParams = 0;

foreach ($this->layers() as $layer) {
if ($layer instanceof Parametric) {
foreach ($layer->parameters() as $parameter) {
$numParams += $parameter->param()->size();
}
}
}

return $numParams;
}

/**
* Initialize the parameters of the layers and warm the optimizer cache.
*/
public function initialize() : void
{
$fanIn = 1;

foreach ($this->layers() as $layer) {
$fanIn = $layer->initialize($fanIn);
}

if ($this->optimizer instanceof Adaptive) {
foreach ($this->layers() as $layer) {
if ($layer instanceof Parametric) {
foreach ($layer->parameters() as $param) {
$this->optimizer->warm($param);
}
}
}
}
}

/**
* Run an inference pass and return the activations at the output layer.
*
* @param Dataset $dataset
* @return Matrix
*/
public function infer(Dataset $dataset) : Matrix
{
$input = Matrix::quick($dataset->samples())->transpose();

foreach ($this->layers() as $layer) {
$input = $layer->infer($input);
}

return $input->transpose();
}

/**
* Perform a forward and backward pass of the network in one call. Returns
* the loss from the backward pass.
*
* @param Labeled $dataset
* @return float
*/
public function roundtrip(Labeled $dataset) : float
{
$input = Matrix::quick($dataset->samples())->transpose();

$this->feed($input);

$loss = $this->backpropagate($dataset->labels());

return $loss;
}

/**
* Feed a batch through the network and return a matrix of activations at the output later.
*
* @param Matrix $input
* @return Matrix
*/
public function feed(Matrix $input) : Matrix
{
foreach ($this->layers() as $layer) {
$input = $layer->forward($input);
}

return $input;
}

/**
* Backpropagate the gradient of the cost function and return the loss.
*
* @param list<string|int|float> $labels
* @return float
*/
public function backpropagate(array $labels) : float
{
[$gradient, $loss] = $this->output->back($labels, $this->optimizer);

foreach ($this->backPass as $layer) {
$gradient = $layer->back($gradient, $this->optimizer);
}

return $loss;
}

/**
* Export the network architecture as a graph in dot format.
*
* @return Encoding
*/
public function exportGraphviz() : Encoding
{
$dot = 'digraph Tree {' . PHP_EOL;
$dot .= ' node [shape=box, fontname=helvetica];' . PHP_EOL;

$layerNum = 0;

foreach ($this->layers() as $layer) {
++$layerNum;

$dot .= " N$layerNum [label=\"$layer\",style=\"rounded\"]" . PHP_EOL;

if ($layerNum > 1) {
$parentId = $layerNum - 1;

$dot .= " N{$parentId} -> N{$layerNum};" . PHP_EOL;
}
}

$dot .= '}';

return new Encoding($dot);
}
public function layers() : Traversable;
}
Loading
Loading