public abstract class AbstractStochasticCachingDiffUpdateFunction extends AbstractStochasticCachingDiffFunction
AbstractStochasticCachingDiffFunction.SamplingMethod
Modifier and Type | Field and Description |
---|---|
protected boolean | skipValCalc |
allIndices, curElement, finiteDifferenceStepSize, gradPerturbed, hasNewVals, HdotV, lastBatch, lastBatchSize, lastElement, lastVBatch, lastXBatch, method, randGenerator, recalculatePrevBatch, returnPreviousValues, sampleMethod, scaleUp, thisBatch, xPerturbed
derivative, value
Constructor and Description |
---|
AbstractStochasticCachingDiffUpdateFunction() |
Modifier and Type | Method and Description |
---|---|
void | calculateStochasticGradient(double[] x, int batchSize): Performs stochastic gradient calculation based on the next batch of batchSize samples and does not apply regularization. |
abstract void | calculateStochasticGradient(double[] x, int[] batch): Performs stochastic gradient calculation based on samples indexed by batch and does not apply regularization. |
abstract double | calculateStochasticUpdate(double[] x, double xScale, int[] batch, double gain): Performs stochastic update of weights x (scaled by xScale) based on samples indexed by batch. |
double | calculateStochasticUpdate(double[] x, double xScale, int batchSize, double gain): Performs stochastic update of weights x (scaled by xScale) based on the next batch of batchSize samples. |
int[] | getSample(int sampleSize): Gets a random sample (this is sampling with replacement). |
double | valueAt(double[] x, double xScale, int batchSize) |
abstract double | valueAt(double[] x, double xScale, int[] batch): Computes value of function for specified value of x (scaled by xScale) only over samples indexed by batch. |
calculateStochastic, clearCache, dataDimension, decrementBatch, derivativeAt, derivativeAt, getBatch, HdotVAt, HdotVAt, HdotVAt, incrementBatch, incrementRandom, initial, lastDerivative, lastValue, scaleUp, valueAt, valueAt
calculate, copy, derivativeAt, ensure, getDerivative, gradientCheck, gradientCheck, randomInitial, valueAt
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
domainDimension
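The method summary describes getSample(int sampleSize) as sampling with replacement. The following is a minimal sketch of that idea, not this class's implementation. The class name SampleSketch, the method sampleWithReplacement, and the explicit dataSize and rng parameters are hypothetical; the sketch assumes sample indices range over 0 to dataSize - 1.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical stand-in, not part of the documented class.
class SampleSketch {

    // Draws sampleSize indices uniformly from [0, dataSize).
    // Each draw is independent, so the same index may appear more than once.
    static int[] sampleWithReplacement(int sampleSize, int dataSize, Random rng) {
        int[] sample = new int[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            sample[i] = rng.nextInt(dataSize);
        }
        return sample;
    }

    public static void main(String[] args) {
        // Asking for 8 indices out of 5 forces at least one repeat.
        int[] batch = sampleWithReplacement(8, 5, new Random(42));
        System.out.println(Arrays.toString(batch));
    }
}
```

Because draws are independent, sampleSize may exceed the number of available samples, which would not be possible when sampling without replacement.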
public AbstractStochasticCachingDiffUpdateFunction()
public int[] getSample(int sampleSize)
sampleSize - number of samples to generate

public abstract double valueAt(double[] x, double xScale, int[] batch)
x - unscaled weights
xScale - how much to scale x by when performing calculations
batch - indices of which samples to compute function over

public double valueAt(double[] x, double xScale, int batchSize)

public abstract double calculateStochasticUpdate(double[] x, double xScale, int[] batch, double gain)
x - unscaled weights
xScale - how much to scale x by when performing calculations
batch - indices of which samples to compute function over
gain - how much to scale adjustments to x

public double calculateStochasticUpdate(double[] x, double xScale, int batchSize, double gain)
x - unscaled weights
xScale - how much to scale x by when performing calculations
batchSize - number of samples to pick next
gain - how much to scale adjustments to x

public abstract void calculateStochasticGradient(double[] x, int[] batch)
x - unscaled weights
batch - indices of which samples to compute function over

public void calculateStochasticGradient(double[] x, int batchSize)
x - unscaled weights
batchSize - number of samples to pick next
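
To make the roles of x, xScale, batch, and gain concrete, here is a small self-contained sketch on a toy least-squares problem. It is an illustration under stated assumptions, not the behavior of any real subclass: the class ScaledUpdateSketch, its FEATURES and TARGETS data, and the helper predict are hypothetical, and the sketch assumes that the effective weights are xScale * x, that the update subtracts gain times the per-sample gradient directly from x, and that the returned double is the loss accumulated over the batch.

```java
import java.util.Arrays;

// Hypothetical toy stand-in that mirrors the documented signatures.
class ScaledUpdateSketch {

    // Toy data (assumption): one feature row and one target per sample.
    static final double[][] FEATURES = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}};
    static final double[] TARGETS = {1.0, 2.0, 3.0};

    // Prediction for one sample using the effective weights xScale * x.
    static double predict(double[] x, double xScale, int index) {
        double sum = 0.0;
        for (int j = 0; j < x.length; j++) {
            sum += xScale * x[j] * FEATURES[index][j];
        }
        return sum;
    }

    // Squared-error loss summed only over the samples indexed by batch.
    static double valueAt(double[] x, double xScale, int[] batch) {
        double value = 0.0;
        for (int index : batch) {
            double residual = predict(x, xScale, index) - TARGETS[index];
            value += 0.5 * residual * residual;
        }
        return value;
    }

    // Updates x in place, one sample at a time, and returns the loss
    // accumulated just before each per-sample step.
    static double calculateStochasticUpdate(double[] x, double xScale,
                                            int[] batch, double gain) {
        double value = 0.0;
        for (int index : batch) {
            double residual = predict(x, xScale, index) - TARGETS[index];
            value += 0.5 * residual * residual;
            for (int j = 0; j < x.length; j++) {
                // Gradient w.r.t. effective weight j is residual * feature j.
                // Assumed step: x_j -= gain * gradient (no regularization).
                x[j] -= gain * residual * FEATURES[index][j];
            }
        }
        return value;
    }

    public static void main(String[] args) {
        double[] x = {0.0, 0.0};
        int[] batch = {0, 1, 2};
        double before = valueAt(x, 1.0, batch);
        for (int epoch = 0; epoch < 200; epoch++) {
            calculateStochasticUpdate(x, 1.0, batch, 0.1);
        }
        double after = valueAt(x, 1.0, batch);
        System.out.println(before + " -> " + after + " at x = " + Arrays.toString(x));
    }
}
```

Keeping a separate xScale lets a caller rescale all weights by changing one number instead of touching every entry of x; how a real subclass combines gain and xScale is not specified on this page.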