/**
 * Internal abstraction to report progress during training.
 *
 * Copyright: 2017 Netflix, Inc.
 * License: $(LINK2 http://www.apache.org/licenses/LICENSE-2.0, Apache License Version 2.0)
 */
module vectorflow.monitor;

private
{
import core.time : dur, Duration, MonoTime, ticksToNSecs;
import std.algorithm : max, filter, sum;
import std.array;
import std.conv : to;
import std.format : format, sformat;
import std.math : isNaN, round;
import std.stdio : stdout;

import vectorflow.math : fabs;
import vectorflow.utils : isTerminal;
}



/**
 * Accumulates per-core training statistics and writes a one-line progress
 * report to stdout.
 *
 * The report contains a bar, the elapsed and remaining time, the average
 * passes, the optional loss, and the throughput.
 *
 * When stdout is a terminal, the line is rewritten in place (it ends with
 * '\r'). Otherwise each report ends with '\n', and output is throttled to
 * roughly one line per 1/`_max_rows_no_term` of progress.
 */
class SGDMonitor {

    bool verbose;
    ulong num_epochs;
    uint num_cores;
    bool with_loss;
    MonoTime start_time;
    MonoTime last_time;

    // Per-core accumulators, indexed by `core_id` (length == num_cores).
    ulong[] examples_seen;
    ulong[] features_seen;
    ushort[] passes_seen;
    double[] acc_loss;

    // Bar layout: '[' at 0, 50 fill cells at 1..51, ']' at 51,
    // percentage text written from index 52 onwards.
    char[] _bar_buff;

    // Reused buffer that each progress line is formatted into,
    // so the hot path does not allocate.
    char[] _buff_stdout_line;

    string _pattern;
    bool _isTerminal;
    ushort _max_rows_no_term;
    // NOTE(review): initialised in the constructor but not read anywhere in this module.
    float _last_percent_displayed;
    ushort _rows_no_term;

    /**
     * Params:
     *   verbose_    = when false, `progress_callback` and `wrap_up` do nothing.
     *   num_epochs_ = total number of passes the progress is measured against.
     *   num_cores_  = number of per-core accumulator slots to allocate.
     *   start_time_ = reference time for the elapsed and throughput figures.
     *   with_loss_  = whether the report includes the average loss per example.
     */
    this(bool verbose_, ulong num_epochs_,
         uint num_cores_, MonoTime start_time_, bool with_loss_)
    {
        verbose = verbose_;
        num_epochs = num_epochs_;
        num_cores = num_cores_;
        start_time = start_time_;
        with_loss = with_loss_;
        examples_seen.length = num_cores;
        examples_seen[] = 0;
        features_seen.length = num_cores;
        features_seen[] = 0;
        passes_seen.length = num_cores;
        passes_seen[] = 0;
        acc_loss.length = num_cores;
        acc_loss[] = 0.0;

        _bar_buff = new char[100];
        _bar_buff[0] = '[';
        _bar_buff[51] = ']';

        // The worst-case line is already close to 200 chars and grows with
        // the number of hour digits, so leave generous headroom for sformat.
        _buff_stdout_line = new char[512];

        if(with_loss_)
            _pattern =
                "Progress: %s | Elapsed: %s | Remaining: %s | %04d passes " ~
                "| Loss: %.4e | %.2e obs/sec | %.2e features/sec";
        else
            _pattern =
                "Progress: %s | Elapsed: %s | Remaining: %s | %04d passes " ~
                "| %.2e obs/sec | %.2e features/sec";

        _isTerminal = isTerminal();
        if(_isTerminal)
            _pattern ~= "\r"; // overwrite the same line on each update
        else
        {
            _pattern ~= "\n";
            _max_rows_no_term = 20;
            _rows_no_term = 0;
        }
        _last_percent_displayed = 0.0f;
    }

    /**
     * Renders the progress bar into `_bar_buff` and returns the used slice.
     *
     * `percentage` is a fraction, nominally in [0, 1]. The returned slice
     * aliases the internal buffer and is overwritten by the next call.
     *
     * The fill is clamped to the 50 available cells, and NaN is treated as 0.
     * This keeps an out-of-range value from overrunning the bar or failing
     * the integer conversion. The textual percentage is printed unclamped.
     */
    private char[] get_progress_bar(float percentage)
    {
        _bar_buff[1..51] = ' ';
        float fill = percentage;
        if(isNaN(fill) || fill < 0)
            fill = 0;
        else if(fill > 1)
            fill = 1;
        auto num_finished = round(fill * 50).to!size_t;
        if(num_finished >= 1)
            _bar_buff[1..num_finished+1] = 'o';
        auto end = sformat(_bar_buff[52..100],
            " (%.1f %%)", percentage * 100).length;
        return _bar_buff[0..52+end];
    }

    /**
     * Formats `d` as "HH:MM:SS", or as "HH:MM:SS.mmm" when `with_ms` is set.
     * Hours are not wrapped at 24.
     */
    private static string time_clock_str(Duration d, bool with_ms)
    {
        auto ds = d.split!("hours", "minutes", "seconds", "msecs")();
        if(with_ms)
            return "%02d:%02d:%02d.%03d".format(
                ds.hours, ds.minutes, ds.seconds, ds.msecs);
        return "%02d:%02d:%02d".format(
            ds.hours, ds.minutes, ds.seconds);
    }

    /**
     * Records the statistics reported by one core and, if due, prints a
     * progress line.
     *
     * Params:
     *   core_id      = index of the reporting core (< num_cores).
     *   epoch        = current pass of that core. Converted to ushort,
     *                  so values above 65535 throw.
     *   num_examples = examples processed since the last report (added).
     *   num_features = features processed since the last report (added).
     *   sum_loss     = loss accumulated since the last report (added).
     */
    void progress_callback(uint core_id, ulong epoch, ulong num_examples,
            ulong num_features, double sum_loss)
    {
        if(!verbose)
            return;
        last_time = MonoTime.currTime;
        examples_seen[core_id] += num_examples;
        features_seen[core_id] += num_features;
        acc_loss[core_id] += sum_loss;
        passes_seen[core_id] = epoch.to!ushort;
        auto ticks = (last_time.ticks - start_time.ticks);
        double seconds = ticksToNSecs(ticks).to!double / 1e9;
        auto time = time_clock_str(last_time - start_time, true);

        // Average only over cores that have started at least one pass.
        auto passes = passes_seen.filter!(x => x > 0).array;
        auto avg_passes = passes.sum.to!float / max(1, passes.length);
        auto total_ex_seen = examples_seen.sum.to!double;
        auto total_feats_seen = features_seen.sum.to!double;
        auto total_loss = acc_loss.sum;
        auto avg_loss_per_ex = total_loss / total_ex_seen;

        // With zero epochs the ratio would be NaN/inf and the bar conversion
        // would throw, so report no progress instead.
        float percent = num_epochs > 0 ? avg_passes / num_epochs : 0.0f;
        bool write_line = true;
        if(!_isTerminal)
        {
            // Throttle: print only when the progress has crossed the next
            // 1/_max_rows_no_term step, or when it has reached completion.
            if(fabs(percent - 1) < 1e-8
                || percent > _rows_no_term.to!float / _max_rows_no_term)
                _rows_no_term++;
            else
                write_line = false;
        }

        if(!write_line)
            return;
        string predict_remaining = "--------";
        if(percent > 0 && percent < 1)
        {
            // Linear extrapolation: total time ~= elapsed / fraction done.
            auto remaining_secs = seconds / percent - seconds;
            auto remaining_dur = dur!"seconds"(round(remaining_secs).to!long);
            predict_remaining = time_clock_str(remaining_dur, false);
        }
        char[] line;
        auto bar = get_progress_bar(percent);
        if(with_loss)
        {
            line = sformat(_buff_stdout_line, _pattern,
                bar, time, predict_remaining, round(avg_passes).to!long,
                avg_loss_per_ex,
                total_ex_seen / seconds, total_feats_seen / seconds);
        }
        else
        {
            line = sformat(_buff_stdout_line, _pattern,
                bar, time, predict_remaining, round(avg_passes).to!long,
                total_ex_seen / seconds, total_feats_seen / seconds);
        }
        stdout.write(line);
        stdout.flush();
    }

    /**
     * Ends the in-place progress line with a newline when running in verbose
     * mode on a terminal with at least one epoch.
     *
     * Non-terminal output already ends each line with '\n'.
     */
    void wrap_up()
    {
        if(!verbose)
            return;
        if(_isTerminal && num_epochs > 0)
        {
            stdout.write("\n");
            stdout.flush();
        }
    }
}