#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Client for a line-based arithmetic task server.

Connects over TCP, solves every ``Task ...: <expr>`` line with
:func:`eval_expr`, and sends the answer when the server shows a ``>`` prompt.
"""
import math
import os
import socket
import sys
from fractions import Fraction
from typing import List, Tuple


def fib_fast(n: int) -> int:
    """Return the n-th Fibonacci number (F(0) = 0, F(1) = 1) via fast doubling.

    Uses O(log n) arithmetic steps and recursion depth of about log2(n).

    Raises:
        ValueError: if ``n`` is negative. Without this guard the recursion
            never terminates, because ``-1 >> 1 == -1``.
    """
    if n < 0:
        raise ValueError(f"fib_fast requires n >= 0, got {n}")

    def pair(k: int) -> Tuple[int, int]:
        """Return (F(k), F(k + 1))."""
        if k == 0:
            return (0, 1)
        a, b = pair(k >> 1)        # (F(m), F(m + 1)) with m = k // 2
        c = a * (2 * b - a)        # F(2m)
        d = a * a + b * b          # F(2m + 1)
        return (d, c + d) if k & 1 else (c, d)

    return pair(n)[0]


def arith_sum(a: int, b: int, s: int) -> int:
    """Return the sum of the arithmetic progression a, a+s, a+2s, ... up to b.

    The progression stops at its last term that does not go past ``b``
    (inclusive), so ``b`` itself need not be a term.
    Example: ``arith_sum(1, 10, 4)`` sums 1 + 5 + 9 = 15.
    A negative ``s`` counts downward. If ``a`` is already past ``b``, the
    progression is empty and the result is 0.

    Raises:
        ZeroDivisionError: if ``s == 0``.
    """
    n = (b - a) // s + 1           # number of terms (floor division handles s < 0)
    if n <= 0:
        return 0
    last = a + (n - 1) * s         # actual last term, which may differ from b
    # n * (a + last) = 2*a*n + n*(n-1)*s is always even, so // 2 is exact.
    return n * (a + last) // 2


def integral_poly(a: int, b: int, coeff: List[int]) -> int:
    """Return the definite integral from ``a`` to ``b`` of a polynomial.

    ``coeff[k]`` is the coefficient of ``x**k`` (ascending powers). The
    integral is computed exactly with rational arithmetic. Individual terms
    ``coeff[k] / (k + 1) * (b**(k+1) - a**(k+1))`` may therefore be
    fractional, as long as their sum is an integer.

    Raises:
        ValueError: if the exact integral is not an integer.
    """
    total = sum(
        (Fraction(c, k + 1) * (pow(b, k + 1) - pow(a, k + 1))
         for k, c in enumerate(coeff)),
        Fraction(0),
    )
    if total.denominator != 1:
        raise ValueError("Non-integer integral encountered; protocol assumption violated")
    return total.numerator


def eval_expr(expr: str) -> int:
    """Evaluate one prefix-notation task expression and return its result.

    The first whitespace-separated token is the operator. The remaining
    tokens are integer operands, plus the literal keywords ``step`` (for
    ``sum``) and ``poly`` (for ``integral``). ``div`` and ``mod`` use Python's
    floor semantics (``//`` and ``%``).

    Raises:
        ValueError: for an empty expression, a missing or non-integer operand,
            a missing keyword, or an unknown operator.
        ZeroDivisionError: for division/modulo by zero, a zero step, or a
            singular linear system.
    """
    t = expr.strip().split()
    if not t:
        raise ValueError("Empty expression")
    op = t[0]

    def tok(k: int) -> str:
        # Bound-check explicitly so a truncated task raises ValueError,
        # not a bare IndexError.
        if k >= len(t):
            raise ValueError(f"Missing operand #{k} for '{op}' in: {expr!r}")
        return t[k]

    def i(k: int) -> int:
        return int(tok(k))

    if op == "plus":
        return i(1) + i(2)
    if op == "minus":
        return i(1) - i(2)
    if op == "times":
        return i(1) * i(2)
    if op == "div":
        return i(1) // i(2)
    if op == "mod":
        return i(1) % i(2)
    if op == "pow":
        # NOTE(review): a negative exponent yields a float here; presumably
        # the server never sends one — confirm.
        return pow(i(1), i(2))
    if op == "sqrt":
        return math.isqrt(i(1))    # floor square root
    if op == "abs":
        return abs(i(1))
    if op == "fact":
        return math.factorial(i(1))
    if op == "gcd":
        return math.gcd(i(1), i(2))
    if op == "lcm":
        a, b = i(1), i(2)
        if a == 0 or b == 0:
            return 0               # gcd(0, 0) == 0 would otherwise divide by zero
        return abs(a // math.gcd(a, b) * b)
    if op == "det2":
        a, b, c, d = i(1), i(2), i(3), i(4)
        return a * d - b * c
    if op == "det3":
        a, b, c = i(1), i(2), i(3)
        d, e, f = i(4), i(5), i(6)
        g, h, j = i(7), i(8), i(9)
        # Cofactor expansion along the first row.
        return a * (e * j - f * h) - b * (d * j - f * g) + c * (d * h - e * g)
    if op == "sum":
        if tok(3) != "step":
            raise ValueError("Expected 'step' in sum expression")
        return arith_sum(i(1), i(2), i(4))
    if op == "fib":
        return fib_fast(i(1))
    if op == "lin1":
        # Solve a*x + b = c for x.
        a, b, c = i(1), i(2), i(3)
        return (c - b) // a
    if op in ("lin2x", "lin2y"):
        # Cramer's rule for a1*x + b1*y = c1, a2*x + b2*y = c2.
        a1, b1, c1, a2, b2, c2 = i(1), i(2), i(3), i(4), i(5), i(6)
        det = a1 * b2 - a2 * b1
        if op == "lin2x":
            return (c1 * b2 - c2 * b1) // det
        return (a1 * c2 - a2 * c1) // det
    if op == "integral":
        a, b = i(1), i(2)
        if tok(3) != "poly":
            raise ValueError("Expected 'poly' in integral expression")
        coeff = list(map(int, t[4:]))
        return integral_poly(a, b, coeff)
    raise ValueError(f"Unknown operator: {op}")


def _lift_int_str_limit() -> None:
    """Disable CPython's int<->str digit limit (3.11+), if present.

    Answers such as large factorials or Fibonacci numbers can exceed the
    default 4300-digit limit. Printing or sending them would then raise
    ValueError.
    """
    setter = getattr(sys, "set_int_max_str_digits", None)
    if setter is not None:
        setter(0)                  # 0 means "no limit"


def run(host: str, port: int) -> None:
    """Connect to the task server and answer each task at the next prompt.

    Every server line is echoed with a ``[SRV]`` prefix. For a line that
    starts with ``Task `` and contains ``:``, the text after the first ``:``
    is solved with :func:`eval_expr` right away. The answer is held until the
    next line whose stripped form starts with ``>``, then sent newline-
    terminated. The loop ends when the server closes the connection.

    Raises:
        ValueError / ArithmeticError: propagated from :func:`eval_expr` when a
            task cannot be solved (after reporting it on stderr).
    """
    _lift_int_str_limit()
    with socket.create_connection((host, port)) as sock, \
            sock.makefile("rb") as reader:
        pending = None  # (expr_str, answer) waiting for the next prompt
        while True:
            # NOTE(review): readline() only returns at b"\n"; if the server's
            # prompt is not newline-terminated this blocks — confirm.
            line = reader.readline()
            if not line:
                break
            s = line.decode("utf-8", errors="replace").rstrip("\r\n")
            print(f"[SRV] {s}")

            if s.startswith("Task ") and ":" in s:
                expr = s.split(":", 1)[1].strip()
                try:
                    ans = eval_expr(expr)
                except (ValueError, ArithmeticError) as err:
                    print(f"[ERROR] cannot solve {expr!r}: {err}", file=sys.stderr)
                    raise
                pending = (expr, ans)
                print(f"[SOLVE] {expr} => {ans}")
                continue

            if s.strip().startswith(">") and pending is not None:
                answer = pending[1]
                print(f"[SEND] {answer}")
                # sendall: a raw socket write may send only part of a large answer.
                sock.sendall(str(answer).encode() + b"\n")
                pending = None
                continue

            if "Here is your flag:" in s or s.endswith("Bye."):
                print(s)


def main() -> None:
    """Read host/port from argv, falling back to $HOST/$PORT, and run the client."""
    host = sys.argv[1] if len(sys.argv) >= 2 else os.getenv("HOST", "127.0.0.1")
    port = int(sys.argv[2]) if len(sys.argv) >= 3 else int(os.getenv("PORT", "5000"))
    run(host, port)


if __name__ == "__main__":
    main()