Bonita Montero
2023-12-10 09:46:18 UTC
Reply
Permalink
[...] threads do all the punching of the non-prime multiples beyond
sqrt( max ). I found that the algorithm is limited by the memory bandwidth,
since the bit-range between sqrt( max ) and max is too large to fit inside
the cache. So I partitioned each thread's range into a number of bit-ranges
that fit inside the L2 cache. With that I got a speedup of a factor of 28
when calculating about the first 210 million primes, i.e. all primes
<= 2^32.
On my Zen4 16-core PC under Windows with clang-cl 16.0.5, producing the
primes without any file output takes only 0.28s. On my Linux PC, a Zen2
64-core CPU, producing the same number of primes takes about 0.09s.
The file output is done with my own buffering. With that, writing the
mentioned prime number range results in a 2.2GB file, which is written
in about two seconds with a PCIe 4.0 SSD.
The first parameter is the highest prime candidate to be generated;
it can be prefixed with 0x for hex values. The second parameter is the
name of the output file; it can be "" to suppress file output. The
third parameter is the number of punching threads; if it is left out, the
number of threads defaults to the number of threads reported by the
runtime.
#include <iostream>
#include <cstdlib>
#include <charconv>
#include <cstring>
#include <vector>
#include <algorithm>
#include <cmath>
#include <bit>
#include <fstream>
#include <string_view>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <utility>
#include <semaphore>
#include <atomic>
#include <new>
#include <span>
#if defined(_MSC_VER)
#include <intrin.h>
#elif defined(__GNUC__) || defined(__clang__)
#include <cpuid.h>
#endif
// Silences MSVC warning 26495 for the whole file.
// NOTE(review): presumably needed because main() declares unions whose empty
// constructors leave their members uninitialised on purpose - confirm.
#if defined(_MSC_VER)
#pragma warning(disable: 26495)
#endif
using namespace std;
// Cache line size in bytes: taken from the standard library when it provides
// hardware_destructive_interference_size, otherwise assumed to be 64.
#if defined(__cpp_lib_hardware_interference_size)
constexpr ptrdiff_t CL_SIZE = hardware_destructive_interference_size;
#else
constexpr ptrdiff_t CL_SIZE = 64;
#endif
// Size of the L2 cache in bytes; defined after main().
size_t getL2Size();
int main( int argc, char **argv )
{
if( argc < 2 )
return EXIT_FAILURE;
auto parse = []( char const *str, bool hex, auto &value )
{
from_chars_result fcr = from_chars( str, str + strlen( str ), value,
!hex ? 10 : 16 );
return fcr.ec == errc() && !*fcr.ptr;
};
char const *sizeStr = argv[1];
bool hex = sizeStr[0] == '0' && (sizeStr[1] == 'x' || sizeStr[1] == 'X');
sizeStr += 2 * hex;
size_t end;
if( !parse( sizeStr, hex, end ) )
return EXIT_FAILURE;
if( (ptrdiff_t)end++ < 0 )
throw bad_alloc();
unsigned hc;
if( argc < 4 || !parse( argv[3], false, hc ) )
hc = jthread::hardware_concurrency();
if( !hc )
return EXIT_FAILURE;
using word_t = size_t;
constexpr size_t
BITS = sizeof(word_t) * 8,
CL_BITS = CL_SIZE * 8;
union alignas(CL_SIZE) ndi_words_cl { word_t words[CL_SIZE /
sizeof(word_t)]; ndi_words_cl() {} };
size_t nCls = (end + CL_BITS - 1) / CL_BITS;
vector<ndi_words_cl> sieveCachelines( (end + CL_BITS - 1) / CL_BITS );
size_t nWords = (end + BITS - 1) / BITS;
span<word_t> sieve( &sieveCachelines[0].words[0], nWords );
auto fill = sieve.begin();
*fill++ = (word_t)0xAAAAAAAAAAAAAAACu;
if( fill != sieve.end() )
for_each( fill, sieve.end(), []( word_t &w ) { w =
(word_t)0xAAAAAAAAAAAAAAAAu; } );
size_t sqrt = (ptrdiff_t)ceil( ::sqrt( (ptrdiff_t)end ) );
auto punch = [&]( auto, size_t start, size_t end, size_t prime )
{
if( start >= end )
return;
size_t multiple = start;
if( prime >= BITS ) [[likely]]
do
sieve[multiple / BITS] &= rotl( (word_t)-2, multiple % BITS );
while( (multiple += prime) < end );
else
{
auto pSieve = sieve.begin() + multiple / BITS;
do
{
word_t
word = *pSieve,
mask = rotl( (word_t)-2, multiple % BITS ),
oldMask;
do
word &= mask,
oldMask = mask,
mask = rotl( mask, prime % BITS ),
multiple += prime;
while( mask < oldMask );
*pSieve++ = word;
} while( multiple < end );
}
};
for( size_t prime = 3; prime < sqrt; ++prime )
{
for( auto pSieve = sieve.begin() + prime / BITS; ; )
{
if( word_t w = *pSieve >> prime % BITS; w ) [[likely]]
{
prime += countr_zero( w );
break;
}
prime = prime + BITS & -(ptrdiff_t)BITS;
if( prime >= sqrt )
break;
}
if( prime >= sqrt ) [[unlikely]]
break;
punch( false_type(), prime * prime, sqrt, prime );
}
auto scan = [&]( size_t start, size_t end, auto fn )
{
for( size_t prime = start; prime < end; )
{
word_t word = sieve[prime / BITS] >> prime % BITS;
bool hasBits = word;
for( unsigned gap; word; word >>= gap, word >>= 1 )
{
gap = countr_zero( word );
if( (prime += gap) >= end ) [[unlikely]]
return;
fn( prime );
if( ++prime >= end ) [[unlikely]]
return;
}
prime = (prime + BITS - hasBits) & -(ptrdiff_t)BITS;
}
};
size_t bitsPerPartition = (end - sqrt) / hc;
using range_t = pair<size_t, size_t>;
vector<pair<size_t, size_t>> ranges;
ranges.reserve( hc );
for( size_t t = hc, rEnd = end, rStart; t--; rEnd = rStart )
{
rStart = sqrt + t * bitsPerPartition;
size_t rClStart = rStart & -(CL_SIZE * 8);
rStart = rClStart >= sqrt ? rClStart : sqrt;
if( rStart >= rEnd )
continue;
ranges.emplace_back( rStart, rEnd );
}
vector<jthread> threads;
threads.reserve( ranges.size() - 1 );
static auto dbl = []( ptrdiff_t x ) { return (double)x; };
double maxPartitionBits = dbl( getL2Size() ) * 8 / 3;
for( range_t const &range : ranges )
{
auto rangePuncher = [&]( size_t rStart, size_t rEnd )
{
double nBits = dbl( rEnd - rStart );
ptrdiff_t nPartitions = (ptrdiff_t)ceil( nBits / maxPartitionBits );
double bitsPerPartition = nBits / dbl( nPartitions );
for( ptrdiff_t p = nPartitions, pEnd = rEnd, pStart; p--; pEnd = pStart )
{
pStart = rStart + (ptrdiff_t)((double)p * bitsPerPartition);
pStart &= -(CL_SIZE * 8);
scan( 3, sqrt,
[&]( size_t prime )
{
punch( true_type(), (pStart + prime - 1) / prime * prime, pEnd,
prime );
} );
}
};
if( &range != &ranges.back() )
threads.emplace_back( rangePuncher, range.first, range.second );
else
rangePuncher( range.first, range.second );
}
threads.resize( 0 );
if( argc >= 3 && !*argv[2] )
return EXIT_SUCCESS;
ofstream ofs;
ofs.exceptions( ofstream::failbit | ofstream::badbit );
ofs.open( argc >= 3 ? argv[2] : "primes.txt", ofstream::binary |
ofstream::trunc );
constexpr size_t
BUF_SIZE = 0x100000,
AHEAD = 32;
union ndi_char { char c; ndi_char() {} };
vector<ndi_char> printBuf( BUF_SIZE + AHEAD - 1 );
auto
bufBegin = printBuf.begin(),
bufEnd = bufBegin,
bufFlush = bufBegin + BUF_SIZE;
auto print = [&]() { ofs << string_view( &bufBegin->c, &to_address(
bufEnd )->c ); };
scan( 2, end,
[&]( size_t prime )
{
auto [ptr, ec] = to_chars( &bufEnd->c, &bufEnd[AHEAD - 1].c, prime );
if( ec != errc() ) [[unlikely]]
throw system_error( (int)ec, generic_category(), "converson failed" );
bufEnd = ptr - &bufBegin->c + bufBegin; // pointer to iterator - NOP
bufEnd++->c = '\n';
if( bufEnd >= bufFlush ) [[unlikely]]
print(), bufEnd = bufBegin;
} );
print();
}
// Returns the size of the L2 cache in bytes.
//
// On x86 the size is read with CPUID function 0x80000006, where bits 31..16
// of ECX hold the size in KiB. 512 KiB are returned instead if the CPU
// doesn't support that function, if it reports a size of zero or if the
// code isn't compiled for an x86 target with a known CPUID intrinsic; the
// result therefore is never zero.
size_t getL2Size()
{
    constexpr size_t DEFAULT_L2_SIZE = 512ull * 1024;
#if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64))) \
    || (!defined(_MSC_VER) && (defined(__GNUC__) || defined(__clang__)) \
        && (defined(__i386__) || defined(__x86_64__)))
    int regs[4] = {};
    // Runs CPUID with the given function code, leaves EAX, EBX, ECX and EDX
    // in regs[0..3] and returns EAX.
    auto cpuId = [&]( unsigned code )
    {
#if defined(_MSC_VER)
        __cpuid( regs, (int)code );
#else
        __cpuid( code, regs[0], regs[1], regs[2], regs[3] );
#endif
        return regs[0];
    };
    // function 0x80000000 returns the highest supported extended function
    if( (unsigned)cpuId( 0x80000000u ) < 0x80000006u )
        return DEFAULT_L2_SIZE;
    cpuId( 0x80000006u );
    size_t l2Size = (size_t)((unsigned)regs[2] >> 16) * 1024;
    return l2Size ? l2Size : DEFAULT_L2_SIZE;
#else
    return DEFAULT_L2_SIZE;
#endif
}