#!/usr/bin/env python3

# 20240512 djb
# SPDX-License-Identifier: LicenseRef-PD-hp OR CC0-1.0 OR 0BSD OR MIT-0 OR MIT

# functionality notes:
# sample(n,t,x) reads sequence of t integers x
# with x[0] < n, x[1] < n-1, etc., all nonnegative
# sample(n,t,x) outputs sorted sequence
# of t distinct positions selected from range(n)
# which user can then map to weight-t binary vector
# or map to weight-t ternary vector, etc.
# if x is a uniform random sequence
# (e.g., generated by rejection sampling)
# then sample(n,t,x) is also a uniform random sequence
# which is tested for small n and t by test() below
# if x is close to uniform
# (e.g., generated by reductions of larger bit strings)
# then sample(n,t,x) is close to uniform;
# can prove bounds on, e.g., divergence or statistical distance
# typically x will be generated from an RNG (or "PRNG" or "DRBG" etc.)
# can use results for cryptosystems such as McEliece and NTRU

# speed notes:
# any merging algorithm will work
# can handle (n,t) via (n,n-t) by generating complement
# can guess n-t is faster if t > n/2, or benchmark
# can replace s = t//2 with anything between 1 and t-1
# can take, e.g., s as largest power of 2 in this range
# can account for merging benchmarks in choice of s
# whole algorithm is parallelizable and vectorizable
# but parallelism is limited if s is close to 1 or t-1

# constant-time notes:
# use any merging network
# e.g., use Batcher odd-even merging network


def merge(L, R):
    """Return one sorted list holding all entries of L and R.

    sample() calls this with two already-sorted lists of (value, tag)
    pairs; any merging algorithm (or merging network) may replace this.
    """
    return sorted(L + R)


def sample(n, t, x):
    """Map t mixed-radix digits x to t sorted distinct positions in range(n).

    Parameters:
      n: size of the range the positions are drawn from.
      t: number of positions to output; requires 0 <= t <= n.
      x: iterable of exactly t integers with 0 <= x[j] < n-j.

    Returns:
      A tuple of t strictly increasing integers in range(n). Each output
      tuple arises from exactly t! of the n!/(n-t)! valid inputs x (checked
      exhaustively by test()), so uniform x gives a uniform t-subset.

    Raises:
      ValueError: if t is out of range, x does not have length t, or some
        x[j] lies outside range(n-j). Explicit raises (not assert) keep the
        checks active under `python -O`.
    """
    if not 0 <= t <= n:
        raise ValueError('sample: need 0 <= t <= n')
    x = list(x)
    if len(x) != t:
        raise ValueError('sample: need len(x) == t')
    for j, xj in enumerate(x):
        if not 0 <= xj < n - j:
            raise ValueError('sample: need 0 <= x[j] < n-j')

    if t == 0:
        return ()
    if t == 1:
        return (x[0],)

    # First s digits pick s positions in range(n); the remaining t-s digits
    # satisfy x[s+j] < (n-s)-j, so they pick t-s positions in range(n-s).
    s = t // 2
    L = sample(n, s, x[:s])
    R = sample(n - s, t - s, x[s:])

    # L[j]-j counts the positions not in L that lie below L[j]; this is
    # nondecreasing in j. Tag 0 sorts left entries before right entries of
    # equal value, so a right value r ends up after exactly the left entries
    # lying below the r-th (0-based) position not taken by L.
    L = [(Lj - j, 0) for j, Lj in enumerate(L)]
    R = [(Rj, 1) for Rj in R]

    result = []
    numL = 0  # number of left entries emitted so far
    for y, right in merge(L, R):
        # left entry j: y + j == L[j]; right entry r: r + #(L below it)
        result.append(y + numL)
        numL += 1 - right
    return tuple(result)


def test(nbound=10):
    """Exhaustively check sample() for every 0 <= t <= n < nbound.

    For each (n, t), all n!/(n-t)! valid inputs x are enumerated by decoding
    r in mixed radix (n, n-1, ..., n-t+1). Every output must be t strictly
    increasing positions in range(n), and each distinct output must occur
    exactly t! times (uniform input gives uniform output). Prints progress
    to stdout; raises AssertionError on failure.
    """
    from collections import Counter

    for n in range(nbound):
        for t in range(n + 1):
            print('n', n, 't', t, flush=True)
            z = 1  # n!/(n-t)!: number of valid inputs x
            for j in range(t):
                z *= n - j
            tfactorial = 1
            for j in range(t):
                tfactorial *= t - j
            results = Counter()
            for r in range(z):
                # decode into a separate local rather than mutating r
                x = []
                rest = r
                for j in range(t):
                    x.append(rest % (n - j))
                    rest //= n - j
                result = sample(n, t, x)
                assert len(result) == t
                assert result == tuple(sorted(set(result)))
                if t > 0:
                    assert result[0] >= 0
                    assert result[t - 1] < n
                results[result] += 1
            assert len(results) * tfactorial == z
            for result, count in results.items():
                assert count == tfactorial


if __name__ == '__main__':
    test()