1428 words
7 minutes
AmateursCTF 2024 - Algorithms
2024-04-29
NOTE

Source code for all the challenges can be found here: https://github.com/les-amateurs/AmateursCTF-Public/tree/main/2024

I love algo

lis#

Problem#

Given an array a of length n (n <= 1e5), find indexes of the Longest increasing subsequence

Input#

3
14 3 13 4 12 7 9 10 4 7 2 1
1 4 1 5 1 6 7 1
1 2 1 3 1 4 5 2 63 29 1

Output#

-1 1 3 5 6 7
0 1 3 5 6
-11 -10 -8 -6 -5 -2 8

Solution#

Notice that negative index is also valid in python!

def check_output(arr, ans):
    """Assert that arr, read at the positions in ans taken in sorted order, strictly increases.

    ans is sorted in place. The positions are never range-checked, so a
    negative position is resolved by Python's normal from-the-end indexing.
    Raises AssertionError on the first adjacent pair that does not increase.
    """
    ans.sort()
    for left, right in zip(ans, ans[1:]):
        assert arr[left] < arr[right]

Simply find the LIS of a + a (the array concatenated with itself) and subtract len(a) from every index before printing it.

from pwn import *

# Connection to the challenge server for the lis task.
nc = remote('chal.amt.rs', 1410)

# Modified https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/lis.py to return index instead of value
def lis(nums, cmp=lambda x, y: x < y):
    """Return the positions in nums of one longest subsequence chained by cmp.

    cmp(a, b) must be true when b may directly follow a; the default gives a
    strictly increasing subsequence. The result lists positions (not values)
    in ascending order. An empty input yields an empty list.
    """
    count = len(nums)
    # prev[i]: position that precedes i in the best chain found ending at i.
    prev = [0] * count
    # tail[k]: position ending the current chain of length k (slot 0 is a dummy).
    tail = [0] * (count + 1)
    best = 0

    for pos, value in enumerate(nums):
        # Binary search for the longest chain that value can extend.
        low, high = 1, best
        while low <= high:
            middle = (low + high) // 2
            if cmp(nums[tail[middle]], value):
                low = middle + 1
            else:
                high = middle - 1

        prev[pos] = tail[low - 1]
        tail[low] = pos
        if low > best:
            best = low

    # Walk the predecessor links back from the end of the longest chain.
    chain = []
    cursor = tail[best]
    for _ in range(best):
        chain.append(cursor)
        cursor = prev[cursor]
    chain.reverse()
    return chain

# First line is the number of test cases, followed by one array per line.
t = int(nc.recvline().decode())
arrs = []
for _ in range(t):
    arrs.append([int(tok) for tok in nc.recvline().split()])

for a in arrs:
    # negative index is also valid: take the LIS of the array laid out twice,
    # then shift every position left by len(a)
    offset = len(a)
    ans = [pos - offset for pos in lis(a + a)]

    nc.sendline(' '.join(str(v) for v in ans).encode())
nc.interactive()
> python solve.py
[+] Opening connection to chal.amt.rs on port 1410: Done
[*] Switching to interactive mode
Good job! Remember to orz larry. Here's your flag amateursCTF{orz-larry-how-is-larry-so-orz-ac3596ad5cba22151e721e205fb5b3120dd6910dc9b42af88ae52dfdfd073333}

orz-larry#

Problem#

Given a string s of length n (n <= 1e5), count the number of distinct subsequences of s (including the empty one), mod 1e9 + 9

Input#

3
abc
aba
aaa

Output#

8
7
4

Solution#

Output the answer to LeetCode 940. Distinct Subsequences II, plus 1 for the empty subsequence.

from pwn import *

nc = remote('chal.amt.rs', 1412)

MOD = 10**9 + 9

# First line is the number of test cases, followed by one string per line.
t = int(nc.recvline().decode())
arrs = [nc.recvline().strip().decode() for _ in range(t)]

for s in arrs:
    # dp[k]: number of distinct subsequences of s[:k], empty one included, mod MOD.
    dp = [1]
    # last[x]: index of the previous occurrence of character x.
    last = {}
    for i, x in enumerate(s):
        # Appending x doubles the count, minus the subsequences that were
        # already produced when x was appended at its previous occurrence.
        cur = dp[-1] * 2
        if x in last:
            cur -= dp[last[x]]
        # Reduce at every step so the stored values stay below MOD instead of
        # growing by one bit per character; Python's % is non-negative here.
        dp.append(cur % MOD)
        last[x] = i

    nc.sendline(str(dp[-1]).encode())
nc.interactive()
> python solve.py
[+] Opening connection to chal.amt.rs on port 1412: Done
[*] Switching to interactive mode
Yay! Good job, here's your flag and remember to orz larry: amateursCTF{orz-larry-how-is-larry-so-orz-4efe27a2edde418184d668992819a62fa4b3a7e6ba5ac3a204be9a66ed7b7105}

omniscient-larry#

Problem#

Given a string s of length n (n <= 1e5) only contains o, z, l, y

Initially there is a string t = o || z || l || y, in one step you can change any character according to the following rule:

o -> lo, o -> oy
z -> yz, z -> zl
l -> ll, l -> oz
y -> yy, y -> zo

Count the number of strings that can be reached from t and are a permutation of s, mod 1e9 + 9

Input#

4
olozy
ollzy
oozzlllyyyyy
ooozzzzllllyyy

Output#

4
5
144
700

Explanation#

In the first test case

o -> lo -> llo -> ozlo -> oyzlo
o -> lo -> llo -> ozlo -> ozloy
o -> lo -> llo -> lozo -> loyzo
o -> lo -> llo -> lozo -> lozoy

In the second test case

l -> ll -> lll -> ozll -> oyzll
l -> ll -> loz -> lozl -> loyzl
l -> ll -> lll -> lloz -> lloyz
y -> zo -> zlo -> zloy -> zlloy
y -> zo -> yzo -> yzlo -> yzllo

Solution#

Define o, z, l, y as the count of 'o', 'z', 'l' and 'y'. Notice that t is a permutation of s if s.{o, z, l, y} = t.{o, z, l, y}

After running brute.rs on s = oozlyy:

use std::collections::HashSet;

/// Modulus used for the answer: 10^9 + 9.
pub const MOD: u32 = 1_000_000_009;

/// Checks that `s` is made up solely of the characters `o`, `z`, `l` and `y`.
///
/// # Panics
///
/// Panics on the first character outside that set.
pub fn validate_string(s: &str) {
    s.chars().for_each(|c| {
        // orz larry
        assert!("ozly".contains(c));
    });
}

/// Brute-force search over every string derivable from a single seed
/// character, counting the ones that are a permutation of the target.
struct Solver {
    /// Bytes of the target string in sorted order, so that "is a permutation
    /// of the target" becomes an equality check on sorted bytes.
    target: Vec<u8>,
    /// Every string the search has already visited.
    vis: HashSet<Vec<u8>>,
    /// Number of distinct matches found so far, kept reduced modulo `MOD`.
    ans: u32,
}

impl Solver {
    /// Counts (modulo `MOD`) the distinct strings reachable from one of the
    /// seeds `o`, `z`, `l`, `y` that are a permutation of `s`, printing each
    /// match on its own line.
    ///
    /// Returns 0 for an empty `s`, since every reachable string has at least
    /// one character.
    ///
    /// # Panics
    ///
    /// Panics (through `validate_string`) if `s` contains a character other
    /// than `o`, `z`, `l` or `y`.
    fn solve(s: String) -> u32 {
        validate_string(&s);

        let mut target = s.into_bytes();
        target.sort_unstable();

        let mut solver = Self {
            target,
            vis: HashSet::new(),
            ans: 0,
        };
        for &seed in b"ozly" {
            solver.dfs(vec![seed]);
        }

        solver.ans
    }

    /// The two 2-character strings that byte `c` may be rewritten into.
    fn expansions(c: u8) -> [&'static [u8; 2]; 2] {
        match c {
            b'o' => [b"lo", b"oy"],
            b'z' => [b"yz", b"zl"],
            b'l' => [b"ll", b"oz"],
            b'y' => [b"yy", b"zo"],
            // Seeds and expansions only ever contain the four bytes above.
            _ => unreachable!(),
        }
    }

    /// Visits `s` and, while it is shorter than the target, every string
    /// obtained from it by rewriting one character.
    fn dfs(&mut self, s: Vec<u8>) {
        // Each rewrite grows the string by exactly one byte, so a string that
        // is already longer than the target can never become a match. Without
        // this guard an empty target would make the search recurse forever.
        if s.len() > self.target.len() {
            return;
        }
        if !self.vis.insert(s.clone()) {
            return;
        }

        if s.len() == self.target.len() {
            // check s is a permutation of the target
            let mut sorted = s.clone();
            sorted.sort_unstable();

            if sorted == self.target {
                // The bytes are ASCII by construction, so nothing is replaced.
                println!("{}", String::from_utf8_lossy(&s));
                self.ans = (self.ans + 1) % MOD;
            }
            return;
        }

        for i in 0..s.len() {
            // perform an expansion - replace s[i] with some 2 character string
            for expansion in Self::expansions(s[i]) {
                let mut next = Vec::with_capacity(s.len() + 1);
                next.extend_from_slice(&s[..i]);
                next.extend_from_slice(expansion);
                next.extend_from_slice(&s[i + 1..]);

                self.dfs(next);
            }
        }
    }
}

/// Runs the brute force on one sample input and prints how many reachable
/// strings are a permutation of it.
fn main() {
    // "oozlyy" is the input whose output is listed right after this program.
    let result = Solver::solve("oozlyy".to_string());
    println!("Result: {}", result);
}
oyyzlo
oyzloy
ozloyy
loyyzo
loyzoy
lozoyy
Result: 6

And on s = oozzlyy

oyyzozl
oyyzloz
oyzoyzl
oyzloyz
ozoyyzl
ozloyyz
loyyzoz
loyzoyz
lozoyyz
zlozoyy
zloyzoy
zloyyzo
zozloyy
zoyzloy
zoyyzlo
yzlozoy
yzloyzo
yzozloy
yzoyzlo
yyzlozo
yyzozlo
Total: 21

I found this pattern

if abs(z - o) > 1:
  ans = 0
else if o == z:
  - ...o...z...o...z...
  - ...z...o...z...o...
else
  - ...z...o...z...
  - ...o...z...o...

Where a box ... is either empty or l..l or y...y. The count of boxes with l is within one unit of the count of boxes with y

Then just do stars and bars. The number of ways to put n identical objects into k labeled boxes is C(n + k - 1, n)

Think of z, o as bars and l, y as stars. There are 3 cases:

  • o = z
    • ...z...o...z...o...
      • Number of boxes is o * 2 + 1
      • Number of ways to put l into o boxes is C(l + o - 1, l)
      • Number of ways to put y into o + 1 boxes is C(y + o, y)
      • ➜ ans = C(l + o - 1, l) * C(y + o, y)
    • ...o...z...o...z...
      • Number of boxes is o * 2 + 1
      • Number of ways to put y into o boxes is C(y + o - 1, y)
      • Number of ways to put l into o + 1 boxes is C(l + o, l)
      • ➜ ans = C(y + o - 1, y) * C(l + o, l)
  • o > z
    • ...o...z...o...
      • Number of boxes is o * 2
      • Number of ways to put y into o boxes is C(y + o - 1, y)
      • Number of ways to put l into o boxes is C(l + o - 1, l)
      • ➜ ans = C(y + o - 1, y) * C(l + o - 1, l)
  • o < z
    • ...z...o...z...
      • Number of boxes is z * 2
      • Number of ways to put y into z boxes is C(y + z - 1, y)
      • Number of ways to put l into z boxes is C(l + z - 1, l)
      • ➜ ans = C(y + z - 1, y) * C(l + z - 1, l)
from pwn import *

MOD = 1_000_000_009

nc = remote('chal.amt.rs', 1411)
# First line is the number of test cases, followed by one string per line.
n = int(nc.recvline().decode())
arrs = []
for _ in range(n):
    arrs.append(nc.recvline().strip().decode())

# https://github.com/cheran-senthil/PyRival/blob/master/pyrival/combinatorics/nCr_mod.py
def make_nCr_mod(max_n=3 * 10**5, mod=MOD):
    """Build and return a function nCr_mod(n, r) giving C(n, r) modulo mod.

    Factorials and inverse factorials are tabulated up to min(max_n, mod - 1).
    The inverse is taken as a (mod - 2)-th power, which presumably relies on
    mod being prime. Arguments are processed digit by digit in base mod, and
    each digit must fit inside the table.
    """
    max_n = min(max_n, mod - 1)
    size = max_n + 1

    fact = [0] * size
    fact[0] = 1
    for k in range(1, size):
        fact[k] = fact[k - 1] * k % mod

    inv_fact = [0] * size
    inv_fact[-1] = pow(fact[-1], mod - 2, mod)
    # Fill downwards: 1/(k-1)! = k * 1/k!
    for k in range(max_n, 0, -1):
        inv_fact[k - 1] = inv_fact[k] * k % mod

    def nCr_mod(n, r):
        res = 1
        while n or r:
            n, a = divmod(n, mod)
            r, b = divmod(r, mod)
            if a < b:
                return 0
            res = res * fact[a] % mod * inv_fact[b] % mod * inv_fact[a - b] % mod
        return res

    return nCr_mod

# Binomial coefficient helper C(n, r) mod MOD, built once with the default table size.
C = make_nCr_mod()

def solve(s):
    """Count, modulo MOD, the reachable strings that are a permutation of s.

    Only the number of 'o', 'z', 'l' and 'y' characters in s matters. The
    'o' and 'z' characters must alternate, and the 'l' and 'y' characters are
    distributed into the gaps between them with stars and bars.
    """
    o, z, l, y = (s.count(ch) for ch in 'ozly')

    if abs(o - z) > 1:
        return 0
    if o == 0 and z == 0:
        # No 'o' or 'z' at all: the only reachable strings are runs of a
        # single letter (l -> ll -> ..., y -> yy -> ...), so there is exactly
        # one match when s is all 'l' or all 'y', and none when it mixes both
        # or is empty. The closed form below would evaluate C(-1, 0) here.
        return 1 if (l == 0) != (y == 0) else 0
    if o == z:
        # Two alternation patterns: starting with 'z' or starting with 'o'.
        return (C(y + o, y) * C(l + o - 1, l) +
                C(y + o - 1, y) * C(l + o, l)) % MOD
    elif o > z:
        return C(y + o - 1, y) * C(l + o - 1, l) % MOD
    else:
        return C(y + z - 1, y) * C(l + z - 1, l) % MOD

# Answer the test cases in the order they were received.
for s in arrs:
    reply = str(solve(s))
    nc.sendline(reply.encode())
nc.interactive()
> python solve.py
[+] Opening connection to chal.amt.rs on port 1411: Done
[*] Switching to interactive mode
Yay! Good job, here's your flag and remember to orz larry: amateursCTF{orz-larry-how-is-larry-so-orz-5318bfae97e201a66dc12069058e1b11d971ac7b24a8c87b2aec826dd39098d4}
AmateursCTF 2024 - Algorithms
https://rewhile.github.io/posts/amateurs-2024/
Author
rewhile
Published at
2024-04-29