#!/usr/bin/env python3
# 20240515 djb
# SPDX-License-Identifier: LicenseRef-PD-hp OR CC0-1.0 OR 0BSD OR MIT-0 OR MIT

# functionality notes:
# sample(n,t,x) reads sequence of t integers x
# with x[0] < n, x[1] < n-1, etc., all nonnegative
# sample(n,t,x) outputs sorted sequence
# of t distinct positions selected from range(n)
# which user can then map to weight-t binary vector
# or map to weight-t ternary vector, etc.
# if x is a uniform random sequence
# (e.g., generated by rejection sampling)
# then sample(n,t,x) is also a uniform random sequence
# which is tested for small n and t by test() below
# if x is close to uniform
# (e.g., generated by reductions of larger bit strings)
# then sample(n,t,x) is close to uniform;
# can prove bounds on, e.g., divergence or statistical distance
# typically x will be generated from an RNG (or "PRNG" or "DRBG" etc.)
# can use results for cryptosystems such as McEliece and NTRU
# can also tweak top level of recursion for s entries +1, t-s entries -1

# speed notes:
# any merging algorithm will work
# can handle (n,t) via (n,n-t) by generating complement
# can guess n-t is faster if t > n/2, or benchmark
# can replace s = t//2 with anything between 1 and t-1
# (this does not affect the output)
# can take, e.g., s as largest power of 2 in this range
# can account for merging benchmarks in choice of s
# whole algorithm is parallelizable and vectorizable
# but parallelism is limited if s is close to 1 or t-1

# constant-time notes:
# use any merging network
# e.g., use Batcher odd-even merging network


def merge(L, R):
    """Return the elements of L and R combined into one sorted list.

    This reference version simply sorts the concatenation, so it does not
    depend on L and R already being sorted. Per the notes above, any merging
    algorithm can be substituted. For constant time, use a merging network
    such as Batcher's odd-even merge.
    """
    return sorted(L + R)


def assertvalidinputs(n, t, x):
    """Assert that x is a valid input sequence for selecting t of range(n).

    Requires 0 <= t <= n, len(x) == t, and 0 <= x[j] < n - j for every j
    (i.e. x[0] < n, x[1] < n-1, ...).

    Raises:
        AssertionError: if any condition fails. These checks are `assert`
        statements, so they are skipped entirely when Python runs with -O.
    """
    assert 0 <= t <= n
    assert len(x) == t
    for j, xj in enumerate(x):
        assert 0 <= xj < n - j


def simplesample(n, t, x):
    """Return the sorted tuple of t distinct positions in range(n) selected by x.

    Straightforward O(n*t) implementation. Start from n-t zeros, then walk x
    from last to first and insert a 1 at index x[j]. The indices of the 1s,
    in increasing order, are the result.

    Args:
        n: size of the range positions are drawn from.
        t: number of positions to select.
        x: iterable of t integers with 0 <= x[j] < n - j.

    Returns:
        Tuple of t distinct increasing integers from range(n).
    """
    x = list(x)  # accept any iterable; validation needs len() and indexing
    assertvalidinputs(n, t, x)
    bits = [0] * (n - t)
    for xj in reversed(x):
        # Before this insertion len(bits) == n - j - 1 >= x[j], so the index
        # is always in range. list.insert matches bits[:xj] + [1] + bits[xj:].
        bits.insert(xj, 1)
    return tuple(j for j, bit in enumerate(bits) if bit)


# NOTE(review): this copy of the file is damaged. It was collapsed onto two
# lines, and a span of code appears to have been lost after this point.
# simplesample2 is cut off mid-expression here. Its remainder, the sample()
# function described in the notes above, and the start of test() are not
# present. Restore them from the upstream source rather than guessing.
# The visible prefix of simplesample2 was:
#   def simplesample2(n,t,x): x = list(x) assertvalidinputs(n,t,x) L = []
#   for xj in reversed(x): L = [Lj for Lj in L
# NOTE(review): damaged fragment, not valid Python as-is. It begins partway
# through simplesample2's list comprehension ("if Lj ...") and continues into
# the end of test(). The code in between appears to have been lost when the
# file was collapsed, and must be restored from the upstream source.
# The visible tail of test() tallies each sampled result in `results`, then
# asserts two things: every distinct result occurs exactly `tfactorial`
# times, and len(results) * tfactorial == z.
if Lj 0: assert result[0] >= 0 assert result[t-1] < n if result not in results: results[result] = 0 results[result] += 1 assert len(results)*tfactorial == z for result in results: assert results[result] == tfactorial if __name__ == '__main__': test()