163 lines
4.4 KiB
Python
163 lines
4.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
import sys
|
|
import os
|
|
import math
|
|
import socket
|
|
from typing import List, Tuple
|
|
|
|
|
|
def fib_fast(n: int) -> int:
    """Return the n-th Fibonacci number using the fast-doubling method.

    Uses the identities
        F(2m)     = F(m) * (2*F(m+1) - F(m))
        F(2m + 1) = F(m)**2 + F(m+1)**2
    while walking the bits of ``n`` from the most significant end, so only
    O(log n) big-integer multiplications are needed and no recursion is used.

    Args:
        n: Index into the sequence, with F(0) = 0 and F(1) = 1.

    Returns:
        F(n).

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        # Without this guard the doubling step never reaches 0 (-1 >> 1 == -1).
        raise ValueError(f"fib index must be non-negative, got {n}")
    # Invariant: (a, b) == (F(m), F(m + 1)) where m is the prefix of n's bits
    # consumed so far (m starts at 0).
    a, b = 0, 1
    for bit in bin(n)[2:]:
        c = a * (2 * b - a)  # F(2m)
        d = a * a + b * b  # F(2m + 1)
        if bit == "1":
            a, b = d, c + d  # m -> 2m + 1
        else:
            a, b = c, d  # m -> 2m
    return a
|
|
|
|
|
|
def arith_sum(a: int, b: int, s: int) -> int:
    """Return the sum of the arithmetic progression a, a+s, a+2s, ... bounded by b.

    The terms are exactly those of ``range(a, b + 1, s)`` for a positive step
    and ``range(a, b - 1, s)`` for a negative step. ``b`` is included only
    when it lies on the progression; otherwise the last term is the final
    one that does not pass ``b``.

    Args:
        a: First term.
        b: Inclusive bound.
        s: Step (may be negative).

    Returns:
        The sum of the progression, or 0 when it is empty (``b`` lies before
        ``a`` in the direction of the step).

    Raises:
        ZeroDivisionError: If ``s`` is 0.
    """
    # Number of terms. Floor division gives the right count for either sign of s.
    n = (b - a) // s + 1
    if n <= 0:
        return 0  # empty progression
    last = a + (n - 1) * s  # actual final term; may fall short of b
    # n * (a + last) == 2*n*a + n*(n-1)*s is always even, so // is exact.
    return n * (a + last) // 2
|
|
|
|
|
|
def integral_poly(a: int, b: int, coeff: List[int]) -> int:
    """Return the exact definite integral of a polynomial from ``a`` to ``b``.

    ``coeff[k]`` is the coefficient of ``x**k``, so the result is
    ``sum(coeff[k] * (b**(k+1) - a**(k+1)) / (k+1))``. The sum is computed
    with exact rational arithmetic, so individual terms may be fractional as
    long as the total is an integer.

    Args:
        a: Lower limit of integration.
        b: Upper limit of integration.
        coeff: Polynomial coefficients, lowest degree first. An empty list
            denotes the zero polynomial.

    Returns:
        The integral value.

    Raises:
        ValueError: If the exact integral is not an integer.
    """
    from fractions import Fraction  # exact rationals; no float rounding

    total = Fraction(0)
    for k, c in enumerate(coeff):
        e = k + 1
        total += Fraction(c * (pow(b, e) - pow(a, e)), e)
    if total.denominator != 1:
        raise ValueError("Non-integer integral encountered; protocol assumption violated")
    return total.numerator
|
|
|
|
|
|
def eval_expr(expr: str) -> int:
    """Evaluate a single task expression sent by the server.

    The expression is a whitespace-separated operator word followed by its
    integer operands (prefix notation), for example ``"plus 2 3"``,
    ``"sum 1 10 step 3"`` or ``"integral 0 2 poly 1 0 3"``. Supported
    operators:

    * ``plus``/``minus``/``times``/``div``/``mod``/``pow`` a b -- binary
      arithmetic; ``div`` and ``mod`` use Python floor semantics.
    * ``sqrt`` a -- integer (floor) square root via ``math.isqrt``.
    * ``abs`` a, ``fact`` a, ``fib`` n -- unary functions.
    * ``gcd``/``lcm`` a b.
    * ``det2`` a b c d -- determinant of [[a, b], [c, d]].
    * ``det3`` followed by 9 entries in row-major order.
    * ``sum`` a b ``step`` s -- arithmetic progression sum (``arith_sum``).
    * ``lin1`` a b c -- x solving ``a*x + b = c`` (floor-divided).
    * ``lin2x``/``lin2y`` a1 b1 c1 a2 b2 c2 -- x or y of the system
      ``a1*x + b1*y = c1``, ``a2*x + b2*y = c2`` by Cramer's rule
      (floor-divided by the determinant).
    * ``integral`` a b ``poly`` c0 c1 ... -- definite integral of
      ``sum(ck * x**k)`` from a to b (``integral_poly``).

    Extra trailing tokens are ignored, except for ``integral``, where every
    token after ``poly`` is a coefficient.

    Args:
        expr: Raw expression text; surrounding whitespace is ignored.

    Returns:
        The result of the operation.

    Raises:
        ValueError: Empty expression, unknown operator, missing operand or
            keyword, or a non-integer operand.
        ZeroDivisionError: Zero divisor in ``div``/``mod``/``lin1``, or a
            singular system in ``lin2x``/``lin2y``.
    """
    t = expr.strip().split()
    if not t:
        raise ValueError("Empty expression")
    op = t[0]

    def tok(k: int) -> str:
        # Raw token k, reported as ValueError rather than a bare IndexError.
        if k >= len(t):
            raise ValueError(f"Missing operand {k} for operator {op!r}")
        return t[k]

    def arg(k: int) -> int:
        # Token k parsed as an integer operand.
        return int(tok(k))

    if op == "plus":
        return arg(1) + arg(2)
    if op == "minus":
        return arg(1) - arg(2)
    if op == "times":
        return arg(1) * arg(2)
    if op == "div":
        return arg(1) // arg(2)
    if op == "mod":
        return arg(1) % arg(2)
    if op == "pow":
        # Note: a negative exponent makes Python's pow return a float.
        return pow(arg(1), arg(2))
    if op == "sqrt":
        return math.isqrt(arg(1))
    if op == "abs":
        return abs(arg(1))
    if op == "fact":
        return math.factorial(arg(1))
    if op == "gcd":
        return math.gcd(arg(1), arg(2))
    if op == "lcm":
        a, b = arg(1), arg(2)
        g = math.gcd(a, b)
        if g == 0:
            return 0  # only when a == b == 0; avoids dividing by gcd(0, 0)
        return a // g * b  # equals a*b/g exactly, so it keeps the sign of a*b
    if op == "det2":
        a, b, c, d = (arg(k) for k in range(1, 5))
        return a * d - b * c
    if op == "det3":
        a, b, c, d, e, f, g, h, j = (arg(k) for k in range(1, 10))
        # Cofactor expansion along the first row.
        return a * (e * j - f * h) - b * (d * j - f * g) + c * (d * h - e * g)
    if op == "sum":
        if tok(3) != "step":
            raise ValueError("Expected 'step' in sum expression")
        return arith_sum(arg(1), arg(2), arg(4))
    if op == "fib":
        return fib_fast(arg(1))
    if op == "lin1":
        a, b, c = arg(1), arg(2), arg(3)
        return (c - b) // a
    if op in ("lin2x", "lin2y"):
        a1, b1, c1, a2, b2, c2 = (arg(k) for k in range(1, 7))
        det = a1 * b2 - a2 * b1
        if op == "lin2x":
            return (c1 * b2 - c2 * b1) // det
        return (a1 * c2 - a2 * c1) // det
    if op == "integral":
        a, b = arg(1), arg(2)
        if tok(3) != "poly":
            raise ValueError("Expected 'poly' in integral expression")
        coeff = [int(x) for x in t[4:]]
        return integral_poly(a, b, coeff)

    raise ValueError(f"Unknown operator: {op}")
|
|
|
|
|
|
def run(host: str, port: int):
    """Connect to the task server and answer each task it sends.

    Line handling:
      * Every received line is echoed with a ``[SRV]`` prefix.
      * A line starting with ``Task `` and containing ``:`` carries an
        expression after the first colon. It is solved immediately with
        ``eval_expr`` and the answer is held as pending.
      * The next line that, after stripping, starts with ``>`` is treated as
        the answer prompt. The pending answer is sent followed by ``\\n``.
      * Lines containing ``Here is your flag:`` or ending in ``Bye.`` are
        printed once more without a prefix.

    The loop ends when the server closes the connection. Exceptions raised
    while solving a task propagate to the caller.

    NOTE(review): lines are read with ``readline``, so this assumes the
    server newline-terminates its ``>`` prompt. Otherwise the read blocks.

    Args:
        host: Server host name or address.
        port: Server TCP port.
    """
    with socket.create_connection((host, port)) as sock:
        # Close the makefile() wrapper explicitly. In CPython the socket's fd
        # is only released once every file object made from it is closed too.
        with sock.makefile("rb", buffering=0) as reader:
            pending = None  # solved answer waiting for the next prompt
            while True:
                line = reader.readline()
                if not line:
                    break  # server closed the connection
                # Replace undecodable bytes rather than dropping the line, so a
                # garbled line cannot silently desynchronise task and prompt.
                s = line.decode(errors="replace").rstrip("\r\n")
                print(f"[SRV] {s}")

                if s.startswith("Task ") and ":" in s:
                    expr = s.partition(":")[2].strip()
                    pending = eval_expr(expr)
                    print(f"[SOLVE] {expr} => {pending}")
                    continue

                if s.strip().startswith(">") and pending is not None:
                    print(f"[SEND] {pending}")
                    # sendall() keeps writing until every byte is sent. A raw
                    # unbuffered write() may send only part of a long answer.
                    sock.sendall(str(pending).encode() + b"\n")
                    pending = None
                    continue

                if "Here is your flag:" in s or s.endswith("Bye."):
                    print(s)
|
|
|
|
|
|
def main():
    """Resolve connection settings and run the client.

    The host comes from ``argv[1]``, else ``$HOST``, else ``127.0.0.1``. The
    port comes from ``argv[2]``, else ``$PORT``, else ``5000``.
    """
    # By default, Python 3.11+ (and some security backports) refuse int<->str
    # conversions longer than 4300 decimal digits. Answers such as large
    # factorials or Fibonacci numbers exceed that when formatted with str() or
    # f-strings, so disable the limit for this process where the knob exists.
    if hasattr(sys, "set_int_max_str_digits"):
        sys.set_int_max_str_digits(0)

    host = sys.argv[1] if len(sys.argv) >= 2 else os.getenv("HOST", "127.0.0.1")
    port = int(sys.argv[2]) if len(sys.argv) >= 3 else int(os.getenv("PORT", "5000"))
    run(host, port)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so a normal run exits with status 0.
    sys.exit(main())
|