#!/usr/bin/env python3
# 20240512 djb
# SPDX-License-Identifier: LicenseRef-PD-hp OR CC0-1.0 OR 0BSD OR MIT-0 OR MIT
# functionality notes:
# sample(n,t,x) reads sequence of t integers x
# with 0 <= x[j] < n-j for each j (x[0] < n, x[1] < n-1, etc.)
# sample(n,t,x) outputs sorted sequence
# of t distinct positions selected from range(n)
# which user can then map to weight-t binary vector
# or map to weight-t ternary vector, etc.
# if x is a uniform random sequence
# (e.g., generated by rejection sampling)
# then sample(n,t,x) is also a uniform random sequence
# which is tested for small n and t by test() below
# if x is close to uniform
# (e.g., generated by reductions of larger bit strings)
# then sample(n,t,x) is close to uniform;
# can prove bounds on, e.g., divergence or statistical distance
# typically x will be generated from an RNG (or "PRNG" or "DRBG" etc.)
# can use results for cryptosystems such as McEliece and NTRU
# speed notes:
# any merging algorithm will work
# can handle (n,t) via (n,n-t) by generating complement
# can guess n-t is faster if t > n/2, or benchmark
# can replace s = t//2 with anything between 1 and t-1
# can take, e.g., s as largest power of 2 in this range
# can account for merging benchmarks in choice of s
# whole algorithm is parallelizable and vectorizable
# but parallelism is limited if s is close to 1 or t-1
# constant-time notes:
# use any merging network
# e.g., use Batcher odd-even merging network
def merge(L, R):
    """Return the elements of L and R combined into one ascending list.

    Reference merging step used by sample(): any algorithm producing the
    same sorted output (e.g. a linear two-way merge, or a merging network
    for constant-time code) may be substituted here.
    """
    # When L and R are each already sorted, CPython's Timsort detects the
    # two runs in the concatenation and merges them in linear time.
    combined = L + R
    return sorted(combined)
def sample(n, t, x):
    """Turn t bounded integers into t sorted distinct positions from range(n).

    Input: 0 <= t <= n, and x is an iterable of t integers with
    0 <= x[j] < n-j for every j.  Violations raise AssertionError (these are
    plain asserts, so the checks disappear under ``python -O``).

    Output: a tuple of t strictly increasing integers in range(n).  If x is
    uniformly random then so is the output t-subset; test() checks this
    exhaustively for small n and t.

    Method: divide and conquer.  The first s = t//2 entries of x choose s
    positions out of range(n); the remaining t-s entries choose t-s slots
    out of the n-s positions left free by those s.  The two partial answers
    are then combined through merge().
    """
    assert 0 <= t <= n
    x = list(x)
    assert len(x) == t
    for j, xj in enumerate(x):
        assert 0 <= xj < n - j

    if t == 0:
        return ()
    if t == 1:
        return (x[0],)

    half = t // 2
    left = sample(n, half, x[:half])
    right = sample(n - half, t - half, x[half:])

    # Re-express each left position as the number of free (non-left)
    # positions below it, i.e. pos - rank.  Tag 0 sorts left entries ahead
    # of right entries (tag 1) carrying the same value.
    tagged = merge(
        [(pos - rank, 0) for rank, pos in enumerate(left)],
        [(slot, 1) for slot in right],
    )

    # Walk the merged order.  Shifting each entry up by the number of left
    # entries already emitted restores left positions exactly and places
    # each right slot index onto the corresponding free position.
    out = []
    lefts_emitted = 0
    for value, is_right in tagged:
        out.append(value + lefts_emitted)
        if not is_right:
            lefts_emitted += 1
    return tuple(out)
def test(nmax=10):
    """Exhaustively check that sample() maps uniform x to uniform subsets.

    For every n in range(nmax) and every 0 <= t <= n, enumerate all
    n!/(n-t)! valid input sequences x (x[j] ranging over range(n-j)) and
    verify that:
      - each output is a sorted tuple of t distinct integers in range(n);
      - the number of distinct outputs times t! equals the number of inputs;
      - every distinct output occurs exactly t! times.
    Together these mean a uniformly random x yields a uniformly random
    t-subset of range(n).

    nmax: exclusive upper bound on n (default 10, the original range).
    Prints one progress line per (n, t) pair; raises AssertionError on
    failure.
    """
    import math
    from collections import Counter

    for n in range(nmax):
        for t in range(n + 1):
            print('n', n, 't', t, flush=True)
            # number of valid inputs: n*(n-1)*...*(n-t+1)
            z = math.factorial(n) // math.factorial(n - t)
            tfactorial = math.factorial(t)
            results = Counter()
            for r in range(z):
                # decode r in mixed radix (n, n-1, ..., n-t+1) into x;
                # use a separate name rather than rebinding the loop variable
                x = []
                rest = r
                for j in range(t):
                    x.append(rest % (n - j))
                    rest //= n - j
                result = sample(n, t, x)
                assert len(result) == t
                assert result == tuple(sorted(set(result)))
                if t > 0:
                    assert result[0] >= 0
                    assert result[t - 1] < n
                results[result] += 1
            assert len(results) * tfactorial == z
            for result, count in results.items():
                assert count == tfactorial, (n, t, result, count)
if __name__ == '__main__':
    # Running the file as a script performs the exhaustive self-check.
    test()