Trait collenchyma_nn::Relu

pub trait Relu<F>: NN<F> {
    fn relu(&self, x: &mut SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>;
    fn relu_plain(&self, x: &SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>;
    fn relu_grad(&self, x: &mut SharedTensor<F>, x_diff: &mut SharedTensor<F>, result: &mut SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>;
    fn relu_grad_plain(&self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result: &SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>;
}

Provides the functionality for a backend to support ReLU (rectified linear unit) operations.

Required Methods

fn relu(&self, x: &mut SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>

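Computes the rectified linear unit, max(x, 0), element-wise over the input tensor x, with complete memory management: the memory the tensors need is allocated and synchronized for you.

Saves the result to result.

For a version without memory management, see relu_plain.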
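As a rough illustration of what relu computes, the sketch below applies ReLU element-wise to plain Rust slices. It does not use Collenchyma: the generic element type F only loosely mirrors the trait's F, slices stand in for SharedTensor, and a String error stands in for Error. These substitutions are assumptions made to keep the example self-contained.

```rust
/// Element-wise ReLU forward pass on slices: result[i] = max(x[i], 0).
/// Slices and the String error are stand-ins, not Collenchyma types.
fn relu<F: Copy + PartialOrd + Default>(x: &[F], result: &mut [F]) -> Result<(), String> {
    if x.len() != result.len() {
        return Err(format!(
            "shape mismatch: x has {} elements, result has {}",
            x.len(),
            result.len()
        ));
    }
    // Default::default() is zero for the built-in float types.
    let zero = F::default();
    for (r, &v) in result.iter_mut().zip(x) {
        *r = if v > zero { v } else { zero };
    }
    Ok(())
}

fn main() {
    let x = [-2.0f32, -0.5, 0.0, 1.5, 3.0];
    let mut result = [0.0f32; 5];
    relu(&x, &mut result).unwrap();
    println!("{:?}", result); // [0.0, 0.0, 0.0, 1.5, 3.0]
}
```

Making the function generic over F mirrors the trait being generic over its element type, so the same code serves f32 and f64.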

fn relu_plain(&self, x: &SharedTensor<F>, result: &mut SharedTensor<F>) -> Result<(), Error>

Computes the ReLU element-wise over the input tensor x without any memory management.

Saves the result to result.

Attention:
To get a correct result, you must allocate and synchronize the memory yourself.
For a memory-managed version, see relu.
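The difference between the managed and plain variants can be modelled with a toy tensor that keeps a host copy and an optional device copy. Everything in this sketch is hypothetical: ToyTensor, alloc_device, sync_to_device and sync_to_host are not Collenchyma APIs. It only mimics the idea that the plain variant trusts the caller to have allocated and synchronized memory, while the managed variant does that work before delegating.

```rust
/// Hypothetical tensor with a host copy and an optional "device" copy.
/// NOT Collenchyma's SharedTensor; it only models allocation and sync.
struct ToyTensor {
    host: Vec<f32>,
    device: Option<Vec<f32>>,
}

impl ToyTensor {
    fn from_host(data: Vec<f32>) -> Self {
        ToyTensor { host: data, device: None }
    }

    fn zeros(len: usize) -> Self {
        ToyTensor { host: vec![0.0; len], device: None }
    }

    /// Allocate device memory if it is missing (filled with zeros here).
    fn alloc_device(&mut self) {
        if self.device.is_none() {
            self.device = Some(vec![0.0; self.host.len()]);
        }
    }

    /// Copy the host data to the device, allocating the device copy if needed.
    fn sync_to_device(&mut self) {
        self.device = Some(self.host.clone());
    }

    /// Copy the device data back to the host, if a device copy exists.
    fn sync_to_host(&mut self) {
        if let Some(d) = &self.device {
            self.host = d.clone();
        }
    }
}

/// "Plain" variant: assumes both device copies exist and that x's device
/// copy is current. It performs no allocation and no copying.
fn relu_plain(x: &ToyTensor, result: &mut ToyTensor) -> Result<(), String> {
    let src = x.device.as_ref().ok_or("x has no device memory")?;
    let dst = result.device.as_mut().ok_or("result has no device memory")?;
    if src.len() != dst.len() {
        return Err("length mismatch".to_string());
    }
    for (r, &v) in dst.iter_mut().zip(src) {
        *r = v.max(0.0);
    }
    Ok(())
}

/// "Managed" variant: synchronizes the input, allocates the output, delegates
/// to the plain variant, then copies the output back to the host (toy choice).
fn relu(x: &mut ToyTensor, result: &mut ToyTensor) -> Result<(), String> {
    x.sync_to_device(); // the input must be current on the device
    result.alloc_device(); // the output only needs memory, not old data
    relu_plain(x, result)?;
    result.sync_to_host();
    Ok(())
}

fn main() {
    let mut x = ToyTensor::from_host(vec![-1.0, 2.0, -3.0, 4.0]);
    let mut result = ToyTensor::zeros(4);

    // The plain call fails before any device memory exists.
    assert!(relu_plain(&x, &mut result).is_err());

    // The managed call takes care of allocation and synchronization.
    relu(&mut x, &mut result).unwrap();
    println!("{:?}", result.host); // [0.0, 2.0, 0.0, 4.0]

    // Change the host data without syncing: the plain call reads stale data.
    x.host = vec![5.0, 5.0, 5.0, 5.0];
    relu_plain(&x, &mut result).unwrap();
    println!("{:?}", result.device); // Some([0.0, 2.0, 0.0, 4.0])
}
```

The last call shows the pitfall the attention note warns about: relu_plain runs without error but computes on stale device data, because x was changed on the host and never synchronized.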

fn relu_grad(&self, x: &mut SharedTensor<F>, x_diff: &mut SharedTensor<F>, result: &mut SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>

Computes the gradient of the ReLU for the input tensor x, with complete memory management: the memory the tensors need is allocated and synchronized for you.

Saves the result to result_diff.

For a version without memory management, see relu_grad_plain.
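A minimal sketch of the element-wise ReLU gradient on plain slices. The parameter roles are an assumption, since this page does not spell them out: x is taken as the forward input, x_diff as the incoming gradient, and result_diff as the output gradient. The sketch also drops the result argument, because for ReLU max(x, 0) > 0 exactly when x > 0, so the forward input and output yield the same mask.

```rust
/// ReLU backward pass on slices (assumed parameter roles, see lead-in):
///   x           - forward input, used to build the mask
///   x_diff      - incoming gradient dL/d(output)
///   result_diff - output gradient dL/d(input)
fn relu_grad(x: &[f32], x_diff: &[f32], result_diff: &mut [f32]) -> Result<(), String> {
    if x.len() != x_diff.len() || x.len() != result_diff.len() {
        return Err("length mismatch".to_string());
    }
    for ((out, &v), &g) in result_diff.iter_mut().zip(x).zip(x_diff) {
        // d/dv max(v, 0) is 1 for v > 0 and 0 for v < 0; it is undefined
        // at v == 0, where this sketch uses the common convention of 0.
        *out = if v > 0.0 { g } else { 0.0 };
    }
    Ok(())
}

fn main() {
    let x = [-2.0f32, 0.0, 1.0, 3.0];
    let upstream = [0.1f32, 0.2, 0.3, 0.4];
    let mut grad = [0.0f32; 4];
    relu_grad(&x, &upstream, &mut grad).unwrap();
    println!("{:?}", grad); // [0.0, 0.0, 0.3, 0.4]
}
```

The incoming gradient passes through unchanged where the input was positive and is zeroed elsewhere.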

fn relu_grad_plain(&self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result: &SharedTensor<F>, result_diff: &mut SharedTensor<F>) -> Result<(), Error>

Computes the gradient of the ReLU for the input tensor x without any memory management.

Saves the result to result_diff.

Attention:
To get a correct result, you must allocate and synchronize the memory yourself.
For a memory-managed version, see relu_grad.

Implementors