For a while now I have been wondering how one could throttle network traffic, disk reads/writes, etc. A quick Google search brought me to the Token Bucket Algorithm as well as a few C++ implementations here and here. I was a bit puzzled at first by how the algorithm works when looking at the code; the two implementations I found operate on atomic integers but the algorithm operates on, well, time. After some head scratching and time at a whiteboard it made sense. Here’s how I understand it:
Imagine you’re walking in a straight line at a constant speed and you are dragging a piece of rope behind you. Every time you want to do something, you first pull the rope toward you a little, so that the length you’re dragging behind you becomes shorter. You repeat this every time you want to do something (that something is what you’re trying to throttle, btw). At the same time, if you choose to do nothing, you release a little bit of rope so that what you’re dragging gets longer, up to the maximum length of the rope. Another way to think of it: the rope, if not yanked on, unwinds at a constant rate up to its maximum length. If, however, you pull on the rope too much, you will eventually bring it all in, and then you will have to wait for it to unwind a little before you can pull on it again.
Now imagine that instead of walking down a straight path you’re actually moving through time, and it should all make sense: pulling on the rope a little is like consuming a token; the length of the rope is the token bucket capacity; and the rate at which the rope unwinds up to its maximum length is the rate at which the token bucket refills if no tokens are consumed. You can also pull all of the rope in at once, and that’s the sudden burst the algorithm allows for, after which the rate is limited by how fast the rope unwinds back behind you, that is, how quickly the bucket refills with tokens. I really hope that explanation makes sense to you!
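To make the rope picture a bit more concrete, here is a minimal sketch of how a single stored time point can stand in for a token count. The helper name `tokens_available` and the parameter name `mark` are made up for this illustration and do not appear in the code further down; the sketch also assumes `time_per_token` is greater than zero.

```cpp
#include <algorithm>
#include <chrono>
#include <cstddef>

using bucket_clock = std::chrono::steady_clock;
using bucket_duration = bucket_clock::duration;
using bucket_time_point = bucket_clock::time_point;

// Hypothetical helper: how much rope (how many tokens) is available at `now`,
// given the one stored time point `mark`. Assumes time_per_token > 0.
inline std::size_t tokens_available(bucket_time_point mark, bucket_time_point now,
                                    bucket_duration time_per_token, std::size_t capacity)
{
    if (now <= mark) return 0;                    // rope fully pulled in, nothing to consume
    auto elapsed = now - mark;                    // how far the rope has unwound
    auto tokens = static_cast<std::size_t>(elapsed / time_per_token);
    return std::min(tokens, capacity);            // the rope has a maximum length
}
```

Consuming a token then amounts to moving `mark` forward by `time_per_token`, and a burst moves it forward by several of those steps at once.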
Some comments about the implementations I found: both use three std::atomic variables where only one is actually needed (unless you want the ability to change bucket capacity and token rate reliably after constructing an instance in a multi-threaded environment, which my implementation supports); the code I linked to above only needs to keep the time variable atomic. Both also operate on integers, and I felt it could be abstracted better using std::chrono. Finally, there’s no need for any atomics if only one thread is consuming tokens, so I decided to create a separate class for that case (not shown below).
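For illustration only, a single-threaded bucket without atomics might look roughly like the sketch below. This is not the separate class mentioned above; the name `token_bucket_st_sketch`, the explicit `start` parameter, and the `try_consume_at` method (which takes the current time as an argument so the logic can be exercised deterministically) are all assumptions made for this example.

```cpp
#include <chrono>
#include <cstddef>
#include <stdexcept>

// Hypothetical single-threaded sketch: plain members, no atomics.
class token_bucket_st_sketch
{
public:
    using clock = std::chrono::steady_clock;
    using duration = clock::duration;
    using time_point = clock::time_point;

    token_bucket_st_sketch(duration time_per_token, std::size_t token_capacity,
                           time_point start, bool full = true)
    : m_time_per_token{ time_per_token },
      m_time_per_burst{ time_per_token * static_cast<duration::rep>(token_capacity) },
      m_time{ full ? start - m_time_per_burst : start }
    {
        if (time_per_token.count() <= 0) throw std::invalid_argument("Invalid token rate!");
        if (!token_capacity) throw std::invalid_argument("Invalid token capacity!");
    }

    // Try to take `tokens` tokens at the given moment; returns false if the bucket is short.
    bool try_consume_at(time_point now, std::size_t tokens = 1) noexcept
    {
        auto min_time = now - m_time_per_burst;               // never hold more than a full bucket
        auto base = m_time < min_time ? min_time : m_time;
        auto new_time = base + m_time_per_token * static_cast<duration::rep>(tokens);
        if (new_time > now) return false;                     // not enough tokens yet
        m_time = new_time;
        return true;
    }

    // Convenience overload that reads the clock itself.
    bool try_consume(std::size_t tokens = 1) noexcept
    {
        return try_consume_at(clock::now(), tokens);
    }

private:
    // Declaration order matters: m_time is initialized from m_time_per_burst.
    duration m_time_per_token;
    duration m_time_per_burst;
    time_point m_time;
};
```

With one consumer thread there is nothing to race against, so the compare-and-exchange loop of a multi-threaded version collapses into a plain assignment.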
Complete source code:
token_bucket.hpp | throttle.cpp
```cpp
// token_bucket.hpp
#pragma once

#include <atomic>
#include <chrono>
#include <cstddef>
#include <stdexcept>
#include <thread>

// Multi-Threaded Version of Token Bucket
class token_bucket_mt
{
public:
    using clock = std::chrono::steady_clock;
    using duration = clock::duration;
    using time_point = clock::time_point;
    using atomic_duration = std::atomic<duration>;
    using atomic_time_point = std::atomic<time_point>;

    token_bucket_mt(std::size_t tokens_per_second, std::size_t token_capacity, bool full = true)
    : token_bucket_mt(to_time_per_token(tokens_per_second), token_capacity, full) {}

    // A full bucket starts at the clock's epoch; an empty one starts at "now".
    token_bucket_mt(duration time_per_token, std::size_t token_capacity, bool full = true)
    : m_time{ full ? time_point{} : clock::now() },
      m_time_per_token{ time_per_token },
      m_time_per_burst{ time_per_token * token_capacity }
    {
        if (time_per_token.count() <= 0) throw std::invalid_argument("Invalid token rate!");
        if (!token_capacity) throw std::invalid_argument("Invalid token capacity!");
    }

    token_bucket_mt(const token_bucket_mt& other) noexcept
    : m_time{ other.m_time.load() },
      m_time_per_token{ other.m_time_per_token.load() },
      m_time_per_burst{ other.m_time_per_burst.load() } {}

    token_bucket_mt& operator = (const token_bucket_mt& other) noexcept
    {
        m_time = other.m_time.load();
        m_time_per_token = other.m_time_per_token.load();
        m_time_per_burst = other.m_time_per_burst.load();
        return *this;
    }

    void set_rate(std::size_t tokens_per_second)
    {
        set_rate(to_time_per_token(tokens_per_second));
    }

    void set_rate(duration time_per_token)
    {
        if (time_per_token.count() <= 0) throw std::invalid_argument("Invalid token rate!");
        m_time_per_token = time_per_token;
    }

    void set_capacity(std::size_t token_capacity)
    {
        if (!token_capacity) throw std::invalid_argument("Invalid token capacity!");
        m_time_per_burst = m_time_per_token.load() * token_capacity;
    }

    // Empty the bucket: no tokens are available until time passes.
    void drain() noexcept
    {
        m_time = clock::now();
    }

    // Fill the bucket: a full burst is available immediately.
    void refill() noexcept
    {
        auto now = clock::now();
        m_time = now - m_time_per_burst.load();
    }

    [[nodiscard]] bool try_consume(std::size_t tokens = 1, duration* time_needed = nullptr) noexcept
    {
        auto now = clock::now();
        auto delay = tokens * m_time_per_token.load(std::memory_order_relaxed);
        auto min_time = now - m_time_per_burst.load(std::memory_order_relaxed);
        auto old_time = m_time.load(std::memory_order_relaxed);
        auto new_time = min_time > old_time ? min_time : old_time;

        while (true)
        {
            new_time += delay;
            if (new_time > now)
            {
                // Not enough tokens; optionally report how long the caller has to wait.
                if (time_needed != nullptr) *time_needed = new_time - now;
                return false;
            }
            if (m_time.compare_exchange_weak(old_time, new_time,
                    std::memory_order_relaxed, std::memory_order_relaxed))
                return true;
            // Another thread changed m_time (old_time now holds its value); clamp again and retry.
            new_time = min_time > old_time ? min_time : old_time;
        }
    }

    // Busy-wait (yielding) until the tokens are available.
    void consume(std::size_t tokens = 1) noexcept
    {
        while (!try_consume(tokens)) std::this_thread::yield();
    }

    // Sleep until the tokens are available.
    void wait(std::size_t tokens = 1) noexcept
    {
        duration time_needed;
        while (!try_consume(tokens, &time_needed)) std::this_thread::sleep_for(time_needed);
    }

private:
    // Rejects a zero rate before dividing by it.
    static duration to_time_per_token(std::size_t tokens_per_second)
    {
        if (!tokens_per_second) throw std::invalid_argument("Invalid token rate!");
        return std::chrono::duration_cast<duration>(std::chrono::seconds(1)) / tokens_per_second;
    }

    atomic_time_point m_time;
    atomic_duration m_time_per_token;
    atomic_duration m_time_per_burst;
};

// Default Token Bucket
using token_bucket = token_bucket_mt;
```
```cpp
// throttle.cpp
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <vector>
#include <latch>
#include "token_bucket.hpp"

#define all(c) for(auto& it : c) it

int main()
{
    using namespace std;
    using namespace std::chrono;

    try
    {
        auto N = 1;
        auto bucket = token_bucket(1ms, 1000, true);
        // hardware_concurrency() may return 0 or 1; always keep at least one worker.
        auto count = max(2u, thread::hardware_concurrency()) - 1;
        auto run = atomic_bool{ true };
        auto total = atomic_uint64_t{};
        auto counts = vector<atomic_uint64_t>(count);
        auto fair_start = latch(count + 2); // workers + reporter + controller
        auto threads = vector<thread>(count);

        // Reporter: prints per-thread and total token counts once a second.
        thread([&]
        {
            fair_start.arrive_and_wait();
            auto start = steady_clock::now();
            auto sec = 0;

            while (run)
            {
                auto cnt = 1;
                for (auto& count : counts)
                    cout << fixed << "Cnt " << cnt++ << ":\t" << count << "\t / \t" << (100.0 * count / total) << " % \n";
                cout << "Total:\t" << total << "\nTime:\t" << duration_cast<seconds>(steady_clock::now() - start).count() << "s\n" << endl;
                this_thread::sleep_until(start + seconds(++sec));
            }
        }).detach();

        // Worker: consumes N tokens at a time and records how many it got.
        auto worker = [&](auto x)
        {
            fair_start.arrive_and_wait();

            while (run)
            {
                bucket.consume(N);
                total += N;
                counts[x] += N;
            }
        };

        // Controller: starts with an empty bucket, then changes rate and capacity, then refills.
        thread([&]
        {
            bucket.drain();
            fair_start.arrive_and_wait();
            this_thread::sleep_for(3s);
            bucket.set_rate(1s);
            bucket.set_capacity(1000000);
            this_thread::sleep_for(3s);
            bucket.refill();
        }).detach();

        all(threads) = thread(worker, --count);

        cin.get();
        run = false;

        all(threads).join();
    }
    catch (exception& ex)
    {
        cerr << ex.what() << endl;
    }
}
```