Trait collenchyma_nn::LogSoftmax
pub trait LogSoftmax<F>: NN<F> {
    fn log_softmax(&self, x: &mut SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>;
    fn log_softmax_plain(&self, x: &SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>;
    fn log_softmax_grad(&self, x: &mut SharedTensor<F>, x_diff: &mut SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>;
    fn log_softmax_grad_plain(&self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>;
}
Provides the functionality for a Backend to support LogSoftmax operations.
Required Methods
fn log_softmax(&self, x: &mut SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>
Computes the logarithmic softmax of the input tensor x, with complete memory management, and saves the result to result.
For a version without memory management, see log_softmax_plain.
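The sketch below shows the computation itself on a plain Rust slice rather than a SharedTensor. It illustrates the math only; it is not collenchyma's native implementation, and the free function log_softmax here is hypothetical. It treats the whole slice as one vector (how a backend splits a multi-dimensional tensor is not covered here) and uses the usual max-subtraction trick, log_softmax(x)_i = x_i - max(x) - ln(sum_j exp(x_j - max(x))), so that exp does not overflow for large inputs.

```rust
/// Numerically stable log-softmax over a 1-D slice (hypothetical helper):
/// log_softmax(x)_i = x_i - max(x) - ln(sum_j exp(x_j - max(x))).
fn log_softmax(x: &[f64]) -> Vec<f64> {
    // Subtracting the maximum keeps exp() from overflowing for large inputs.
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // ln of the sum of shifted exponentials (the "log-sum-exp" term).
    let log_sum_exp = x.iter().map(|&v| (v - max).exp()).sum::<f64>().ln();
    x.iter().map(|&v| v - max - log_sum_exp).collect()
}

fn main() {
    let y = log_softmax(&[1.0, 2.0, 3.0]);
    // exp(y) is a probability distribution, so its entries sum to 1.
    let total: f64 = y.iter().map(|v| v.exp()).sum();
    println!("log-probabilities: {:?}, sum of exp: {}", y, total);
}
```

Because each output is the input shifted by the same constant, differences between entries are preserved: y[2] - y[1] equals x[2] - x[1].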
fn log_softmax_plain(&self, x: &SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>
Computes the logarithmic softmax of the input tensor x, without any memory management, and saves the result to result.
Attention:
For a correct result, you must manage memory allocation and synchronization yourself.
For a memory-managed version, see log_softmax.
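The split between the managed and plain variants can be illustrated loosely with host-side buffers. In this hypothetical sketch, log_softmax_plain writes into a buffer the caller has already sized and reports a size mismatch as an error, while the managed log_softmax prepares the output itself and then delegates. This is an analogy only: in collenchyma, memory management concerns allocating and synchronizing SharedTensor memory for the backend's device, which the sketch does not model at all.

```rust
/// Hypothetical "plain" variant: the caller owns and sizes the output
/// buffer; nothing is allocated here. A length mismatch is reported as an
/// error instead of being fixed up.
fn log_softmax_plain(x: &[f64], result: &mut [f64]) -> Result<(), String> {
    if x.len() != result.len() {
        return Err(format!(
            "shape mismatch: x has {} elements, result has {}",
            x.len(),
            result.len()
        ));
    }
    // Same numerically stable formula as a standalone log-softmax.
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = x.iter().map(|&v| (v - max).exp()).sum::<f64>().ln();
    for (r, &v) in result.iter_mut().zip(x) {
        *r = v - max - log_sum_exp;
    }
    Ok(())
}

/// Hypothetical "managed" variant: prepares the output buffer itself
/// (here simply by resizing a Vec) and then delegates to the plain version.
fn log_softmax(x: &[f64], result: &mut Vec<f64>) -> Result<(), String> {
    result.resize(x.len(), 0.0);
    log_softmax_plain(x, result)
}

fn main() {
    let x = [0.5, 1.5];

    // Managed: an empty Vec is fine, the function sizes it.
    let mut managed = Vec::new();
    log_softmax(&x, &mut managed).unwrap();
    println!("managed: {:?}", managed);

    // Plain: the caller must supply a correctly sized buffer.
    let mut too_small = [0.0; 1];
    println!("plain with wrong size: {:?}", log_softmax_plain(&x, &mut too_small));
}
```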
fn log_softmax_grad(&self, x: &mut SharedTensor<F>, x_diff: &mut SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>
Computes the gradient of the logarithmic softmax over the input tensor x, with complete memory management, and saves the result to result_diff.
For a version without memory management, see log_softmax_grad_plain.
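A sketch of the gradient computation, again on plain slices with hypothetical function names. It assumes that x holds the forward output y = log_softmax(z) and that x_diff holds the incoming gradient dy = dL/dy; the trait signature alone does not confirm this argument convention, so check the backend before relying on it. Under that assumption, differentiating y_i = z_i - ln(sum_k exp(z_k)) gives dL/dz_j = dy_j - exp(y_j) * sum_i dy_i.

```rust
/// Forward pass, needed to produce `y` (hypothetical helper).
fn log_softmax(z: &[f64]) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = z.iter().map(|&v| (v - max).exp()).sum::<f64>().ln();
    z.iter().map(|&v| v - max - log_sum_exp).collect()
}

/// Backward pass of log-softmax (hypothetical helper).
/// Assumes `y` is the forward output (log-probabilities) and `dy` is the
/// gradient of the loss with respect to `y`. Returns the gradient with
/// respect to the forward input: dz_j = dy_j - exp(y_j) * sum_i dy_i.
fn log_softmax_grad(y: &[f64], dy: &[f64]) -> Vec<f64> {
    assert_eq!(y.len(), dy.len(), "y and dy must have the same length");
    let sum_dy: f64 = dy.iter().sum();
    y.iter()
        .zip(dy)
        .map(|(&yj, &dyj)| dyj - yj.exp() * sum_dy)
        .collect()
}

fn main() {
    let z = [0.2, -1.0, 3.0];
    let y = log_softmax(&z);
    // Negative log-likelihood of class 0: L = -y[0], so dy = [-1, 0, 0].
    let dy = [-1.0, 0.0, 0.0];
    let dz = log_softmax_grad(&y, &dy);
    println!("dL/dz = {:?}", dz);
}
```

For the negative log-likelihood case in main, the formula reduces to the familiar softmax(z) minus the one-hot target vector.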
fn log_softmax_grad_plain(&self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>
Computes the gradient of the logarithmic softmax over the input tensor x, without any memory management, and saves the result to result_diff.
Attention:
For a correct result, you must manage memory allocation and synchronization yourself.
For a memory-managed version, see log_softmax_grad.
Implementors
impl LogSoftmax<f32> for Backend<Native>
impl LogSoftmax<f64> for Backend<Native>
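The two implementors differ only in the element type. One common way to provide both from a single body is a declarative macro, sketched below with a hypothetical NativeBackend type and a simplified, slice-based LogSoftmaxOps trait. Neither is collenchyma's actual API, and this page does not say whether the crate itself uses a macro for these implementations.

```rust
/// Hypothetical stand-in for a CPU backend; not collenchyma's Backend<Native>.
struct NativeBackend;

/// Simplified, slice-based mirror of the trait's forward method (hypothetical).
trait LogSoftmaxOps<F> {
    fn log_softmax(&self, x: &[F], result: &mut [F]) -> Result<(), String>;
}

/// Generates one implementation per float type, so f32 and f64 share one body.
macro_rules! impl_log_softmax {
    ($t:ty) => {
        impl LogSoftmaxOps<$t> for NativeBackend {
            fn log_softmax(&self, x: &[$t], result: &mut [$t]) -> Result<(), String> {
                if x.len() != result.len() {
                    return Err("length mismatch between x and result".to_string());
                }
                // Numerically stable: shift by the maximum before exponentiating.
                let max = x.iter().cloned().fold(<$t>::NEG_INFINITY, <$t>::max);
                let log_sum_exp = x.iter().map(|&v| (v - max).exp()).sum::<$t>().ln();
                for (r, &v) in result.iter_mut().zip(x) {
                    *r = v - max - log_sum_exp;
                }
                Ok(())
            }
        }
    };
}

impl_log_softmax!(f32);
impl_log_softmax!(f64);

fn main() {
    let backend = NativeBackend;
    // Explicitly typed slices let the compiler pick the f32 or f64 impl.
    let x32: &[f32] = &[1.0, 2.0];
    let x64: &[f64] = &[1.0, 2.0];
    let mut out32 = [0.0f32; 2];
    let mut out64 = [0.0f64; 2];
    backend.log_softmax(x32, &mut out32[..]).unwrap();
    backend.log_softmax(x64, &mut out64[..]).unwrap();
    println!("f32: {:?}, f64: {:?}", out32, out64);
}
```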