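/// Implements the `ISolver` trait for an SGD solver type.
///
/// The implementing type is expected to carry `backend` and `history` fields,
/// as used in the macro body. A sketch of the expected usage, for a solver
/// type like the `Momentum` solver re-exported at the bottom of this module:
///
/// ```ignore
/// impl_isolver_sgd!(Momentum<SolverB>);
/// ```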
#[macro_export]
macro_rules! impl_isolver_sgd {
    ($t:ty) => (
        impl<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32> + 'static> ISolver<SolverB, NetB> for $t {
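            /// Initializes the solver, allocating one zero-filled history
            /// tensor per learnable weight of the network. SGD variants such
            /// as Momentum use the history to accumulate previous update values.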
            fn init(&mut self, net: &Layer<NetB>) {
                self.history = Vec::with_capacity(net.learnable_weights_gradients().len());

                for weight_gradient in net.learnable_weights_gradients() {
                    // Allocate a history tensor with the same shape as the
                    // weight gradient, on the solver backend's device.
                    let shape = weight_gradient.read().unwrap().desc().clone();
                    let mut tensor = SharedTensor::new(IBackend::device(&*self.backend), &shape).unwrap();

                    // Zero-fill the tensor so the first update starts from an
                    // empty history.
                    let filler = ::weight::FillerType::Constant { value: 0f32 };
                    filler.fill(&mut tensor);

                    let history_tensor = Arc::new(RwLock::new(tensor));
                    self.history.push(history_tensor);
                }
            }
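            /// Computes one update for each learnable weight of the network.
            ///
            /// Derives the learning rate for the current iteration, clips the
            /// gradients if configured, and then normalizes each gradient and
            /// computes its update value, honoring the per-weight learning rate.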
            fn compute_update(&mut self, config: &SolverConfig, net: &mut Layer<NetB>, iter: usize) {
                let rate = config.get_learning_rate(iter);

                // Clip the gradients once across all weights, then normalize
                // and compute the update value for each weight individually.
                SGDSolver::<SolverB, NetB>::clip_gradients(self, config, net);
                for (weight_id, weight_gradient) in net.learnable_weights_gradients().iter().enumerate() {
                    SGDSolver::<SolverB, NetB>::normalize(self, config, weight_gradient);
                    // The global learning rate is combined with the
                    // per-weight learning rate of the network.
                    SGDSolver::<SolverB, NetB>::compute_update_value(self,
                                                                     config,
                                                                     weight_gradient,
                                                                     weight_id,
                                                                     &rate,
                                                                     &net.learnable_weights_lr()[weight_id].unwrap());
                }
            }
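            /// Returns the backend the solver runs its computations on.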
            fn backend(&self) -> &SolverB {
                &self.backend
            }
        }
    )
}
pub use self::momentum::Momentum;
pub mod momentum;