1 /**
2  * Internal abstraction to report progress during training.
3  *
4  * Copyright: 2017 Netflix, Inc.
5  * License: $(LINK2 http://www.apache.org/licenses/LICENSE-2.0, Apache License Version 2.0)
6  */
7 module vectorflow.monitor;
8 
9 private
10 {
11 import core.time : dur, Duration, MonoTime, ticksToNSecs;
12 import std.algorithm : max, filter, sum;
13 import std.array;
14 import std.conv : to;
15 import std.format : format, sformat;
16 import std.math : round;
17 import std.stdio : stdout;
18 
19 import vectorflow.math : fabs;
20 import vectorflow.utils : isTerminal;
21 }
22 
23 
24 
/**
 * Accumulates per-core training statistics and prints a single-line
 * progress report (progress bar, elapsed/remaining time, passes,
 * optional loss, throughput) to stdout.
 *
 * When stdout is a terminal the line ends with a carriage return so each
 * report overwrites the previous one; otherwise each report is a new line
 * and the number of printed rows is throttled.
 */
class SGDMonitor {

    /// When false, `progress_callback` and `wrap_up` return immediately.
    bool verbose;
    /// Total number of epochs; denominator of the completion fraction.
    ulong num_epochs;
    /// Number of cores; length of every per-core array below.
    uint num_cores;
    /// Whether the report line includes the average loss per example.
    bool with_loss;
    /// Reference point for the elapsed time.
    MonoTime start_time;
    /// Timestamp taken at the most recent `progress_callback` call.
    MonoTime last_time;

    /// Per-core running totals, indexed by core id.
    ulong[] examples_seen;
    ulong[] features_seen;
    /// Per-core latest epoch number reported.
    ushort[] passes_seen;
    /// Per-core accumulated loss.
    double[] acc_loss;

    // Reused buffer for the progress bar: '[' + 50 cells + ']' at
    // indices 0..51, followed by the formatted percentage text.
    char[] _bar_buff;

    // Reused buffer the full report line is formatted into.
    char[] _buff_stdout_line;

    // Format string of the report line, chosen once in the constructor.
    string _pattern;
    bool _isTerminal;
    // Non-terminal output only: row budget and rows printed so far.
    ushort _max_rows_no_term;
    float _last_percent_displayed;
    ushort _rows_no_term;

    /**
     * Params:
     *   verbose_    = enable reporting
     *   num_epochs_ = total number of epochs
     *   num_cores_  = number of cores reporting progress
     *   start_time_ = time origin used for elapsed time
     *   with_loss_  = include the loss column in the report
     */
    this(bool verbose_, ulong num_epochs_,
            uint num_cores_, MonoTime start_time_, bool with_loss_)
    {
        verbose = verbose_;
        num_epochs = num_epochs_;
        num_cores = num_cores_;
        start_time = start_time_;
        with_loss = with_loss_;
        examples_seen.length = num_cores;
        examples_seen[] = 0;
        features_seen.length = num_cores;
        features_seen[] = 0;
        passes_seen.length = num_cores;
        passes_seen[] = 0;
        acc_loss.length = num_cores;
        // double elements default to NaN, so this reset is required.
        acc_loss[] = 0.0;

        _bar_buff = new char[100];
        _bar_buff[0] = '[';
        _bar_buff[51] = ']';

        _buff_stdout_line = new char[240];

        if(with_loss_)
            _pattern = 
            "Progress: %s | Elapsed: %s | Remaining: %s | %04d passes " ~
            "| Loss: %.4e | %.2e obs/sec | %.2e features/sec";
        else
            _pattern =
            "Progress: %s | Elapsed: %s | Remaining: %s | %04d passes " ~
            "| %.2e obs/sec | %.2e features/sec";

        _isTerminal = isTerminal();
        if(_isTerminal)
            _pattern ~= "\r";
        else
        {
            _pattern ~= "\n";
            _max_rows_no_term = 20;
            _rows_no_term = 0;
        }
        _last_percent_displayed = 0.0f;
    }

    /**
     * Renders the progress bar into `_bar_buff` and returns the used slice.
     *
     * The returned slice aliases the internal buffer and is overwritten by
     * the next call.
     *
     * Params:
     *   percentage = completion fraction; the bar fill is clamped to
     *                [0, 1], the printed number is not.
     */
    private char[] get_progress_bar(float percentage)
    {
        _bar_buff[1..51] = ' ';
        // Clamp the fill so it never runs over the closing ']' and so that
        // negative or NaN input (NaN fails `> 0`) yields an empty bar
        // instead of a conversion error.
        float fill = percentage > 0
            ? (percentage < 1 ? percentage : 1.0f)
            : 0.0f;
        auto num_finished = round(fill * 50).to!size_t;
        if(num_finished >= 1)
            _bar_buff[1..num_finished+1] = 'o';
        auto end = sformat(_bar_buff[52..100],
                " (%.1f %%)", percentage * 100).length;
        return _bar_buff[0..52+end];
    }

    /**
     * Formats a duration as `HH:MM:SS`, or `HH:MM:SS.mmm` when `with_ms`
     * is true.
     */
    private static string time_clock_str(Duration d, bool with_ms)
    {
        auto ds = d.split!("hours", "minutes", "seconds", "msecs")();
        if(with_ms)
            return "%02d:%02d:%02d.%03d".format(
                    ds.hours, ds.minutes, ds.seconds, ds.msecs);
        return "%02d:%02d:%02d".format(
                    ds.hours, ds.minutes, ds.seconds);
    }

    /**
     * Records progress reported by one core and, if appropriate, prints an
     * updated report line to stdout.
     *
     * Params:
     *   core_id      = index of the reporting core (indexes the per-core arrays)
     *   epoch        = epoch number reported by this core
     *   num_examples = examples processed since the previous report
     *   num_features = features processed since the previous report
     *   sum_loss     = loss accumulated since the previous report
     */
    void progress_callback(uint core_id, ulong epoch, ulong num_examples,
            ulong num_features, double sum_loss)
    {
        if(!verbose)
            return;
        last_time = MonoTime.currTime;
        examples_seen[core_id] += num_examples;
        features_seen[core_id] += num_features;
        acc_loss[core_id] += sum_loss;
        // Checked conversion: throws if epoch does not fit in a ushort.
        passes_seen[core_id] = epoch.to!ushort;
        auto ticks = (last_time.ticks - start_time.ticks);
        double seconds = ticksToNSecs(ticks).to!double / 1e9;
        auto time = time_clock_str(last_time - start_time, true);

        // Average only over cores that have reported at least one pass.
        auto passes = passes_seen.filter!(x => x > 0).array;
        auto avg_passes = passes.sum.to!float / max(1, passes.length);
        auto total_ex_seen = examples_seen.sum.to!double;
        auto total_feats_seen = features_seen.sum.to!double;
        auto total_loss = acc_loss.sum;
        auto avg_loss_per_ex = total_loss / total_ex_seen;

        auto percent = avg_passes/ num_epochs;
        bool write_line = true;
        if(!_isTerminal)
        {
            // Without a terminal, print only when progress crosses the next
            // 1/_max_rows_no_term step, or when training is complete.
            if(fabs(percent - 1) < 1e-8
                    || percent > _rows_no_term.to!float / _max_rows_no_term)
                _rows_no_term++;
            else
                write_line = false;
        }

        if(!write_line)
            return;
        string predict_remaining = "--------";
        if(percent > 0 && percent < 1)
        {
            // Linear extrapolation from elapsed time and completion fraction.
            auto remaining_secs = seconds / percent - seconds;
            auto remaining_dur = dur!"seconds"(round(remaining_secs).to!long);
            predict_remaining = time_clock_str(remaining_dur, false);
        }
        char[] line;
        auto bar = get_progress_bar(percent);
        if(with_loss)
        {
            line = sformat(_buff_stdout_line, _pattern,
                bar, time, predict_remaining, round(avg_passes).to!long,
                avg_loss_per_ex,
                total_ex_seen / seconds, total_feats_seen / seconds);
        }
        else
        {
            line = sformat(_buff_stdout_line, _pattern,
                bar, time, predict_remaining, round(avg_passes).to!long,
                total_ex_seen / seconds, total_feats_seen / seconds);
        }
        stdout.write(line);
        stdout.flush();
    }

    /**
     * Terminates the in-place report line with a newline when writing to a
     * terminal. Does nothing when not verbose or when `num_epochs` is 0.
     */
    void wrap_up()
    {
        if(!verbose)
            return;
        if(_isTerminal && num_epochs > 0)
        {
            stdout.write("\n");
            stdout.flush();
        }
    }
}