diff --git a/assets/dash.css b/assets/dash.css new file mode 100644 index 000000000..fa03745b9 --- /dev/null +++ b/assets/dash.css @@ -0,0 +1,22 @@ +:root { + --font-color: #F1F1F1; + --dropdown-bg: #005050; +} + +body { + background-color: black !important; + color: var(--font-color) !important; +} + +.rc-slider-mark-text { + color: var(--font-color) !important; +} + +.Select-control, .Select-menu-outer, .Select-value-label, .Select-option { + color: var(--font-color) !important; + background-color: var(--dropdown-bg) !important; +} + +h1, h2, h3, h4, h5, h6 { + color: var(--font-color) !important; +} diff --git a/cache_data.py b/cache_data.py new file mode 100644 index 000000000..d1d5c0a13 --- /dev/null +++ b/cache_data.py @@ -0,0 +1,185 @@ +import numpy as np + +import json +import glob +import os + + +env_names = sorted([ + 'breakout', + 'impulse_wars', + 'pacman', + 'tetris', + 'g2048', + 'moba', + 'pong', + 'tower_climb', + 'grid', + 'nmmo3', + 'snake', + 'tripletriad' +]) + +HYPERS = [ + 'train/learning_rate', + 'train/ent_coef', + 'train/gamma', + 'train/gae_lambda', + 'train/vtrace_rho_clip', + 'train/vtrace_c_clip', + 'train/clip_coef', + 'train/vf_clip_coef', + 'train/vf_coef', + 'train/max_grad_norm', + 'train/adam_beta1', + 'train/adam_beta2', + 'train/adam_eps', + 'train/prio_alpha', + 'train/prio_beta0', + 'train/bptt_horizon', + 'train/num_minibatches', + 'train/minibatch_size', + 'policy/hidden_size', + 'env/num_envs', +] + +ALL_KEYS = [ + 'agent_steps', + 'cost', + 'environment/score', + 'environment/perf' +] + HYPERS + +def pareto_idx(steps, costs, scores): + idxs = [] + for i in range(len(steps)): + better = [scores[j] >= scores[i] and + costs[j] < costs[i] and steps[j] < steps[i] + for j in range(len(scores))] + if not any(better): + idxs.append(i) + + return idxs + +def load_sweep_data(path): + data = {} + keys = None + for fpath in glob.glob(path): + if 'cache.json' in fpath: + continue + + with open(fpath, 'r') as f: + exp = json.load(f) + + if not data: + for kk in exp.keys(): + if kk == 'data': + for k, v in exp[kk][-1].items(): + data[k] = [] + else: + data[kk] = [] + + discard = False + for kk in list(data.keys()): + if kk not in exp and kk not in exp['data'][-1]: + discard = True + break + + if discard: + continue + + for kk in list(data.keys()): + if kk in exp: + v = exp[kk] + sweep_key = f'sweep/{kk}/distribution' + if sweep_key in data and exp[sweep_key] == 'logit_normal': + v = 1 - v + elif kk in ('train/vtrace_rho_clip', 'train/vtrace_c_clip'): + v = max(v, 0.1) + + data[kk].append(v) + else: + data[kk].append(exp['data'][-1][kk]) + + steps = data['agent_steps'] + costs = data['cost'] + scores = data['environment/score'] + + idxs = pareto_idx(steps, costs, scores) + + # Filter to pareto + for k in data: + data[k] = [data[k][i] for i in idxs] + + # Monkey patch: Cap performance + data['environment/perf'] = [min(e, 1.0) for e in data['environment/perf']] + + # Monkey patch: Adjust steps by frameskip if present + if 'env/frameskip' in data: + skip = data['env/frameskip'] + data['agent_steps'] = [n*m for n, m in zip(data['agent_steps'], skip)] + + return data + +def cached_sweep_load(path, env_name): + cache_file = os.path.join(path, 'c_cache.json') + if not os.path.exists(cache_file): + data = load_sweep_data(os.path.join(path, '*.json')) + with open(cache_file, 'w') as f: + json.dump(data, f) + + with open(cache_file, 'r') as f: + data = json.load(f) + + print(f'Loaded {env_name}') + return data + +def compute_tsne(): + data = {name: 
cached_sweep_load(f'experiments/logs/puffer_{name}', name) for name in env_names} + + flat = [] + flat_mmin = [] + flat_mmax = [] + for env in env_names: + flat.append(np.stack([data[env][hyper] for hyper in HYPERS], axis=1)) + flat_mmin.append(np.stack([data[env][f'sweep/{hyper}/min'] for hyper in HYPERS], axis=1)) + flat_mmax.append(np.stack([data[env][f'sweep/{hyper}/max'] for hyper in HYPERS], axis=1)) + + flat_distribution = [data[env][f'sweep/{hyper}/distribution'] for env in env_names for hyper in HYPERS] + + flat = np.concatenate(flat, axis=0) + flat_mmin = np.concatenate(flat_mmin, axis=0).min(axis=0) + flat_mmax = np.concatenate(flat_mmax, axis=0).max(axis=0) + + normed = flat.copy() + for i in range(len(HYPERS)): + dist = flat_distribution[i] + if 'log' in dist or 'pow2' in dist: + flat_mmin[i] = np.log(flat_mmin[i]) + flat_mmax[i] = np.log(flat_mmax[i]) + normed[:, i] = np.log(flat[:, i]) + + normed[:, i] = (normed[:, i] - flat_mmin[i]) / (flat_mmax[i] - flat_mmin[i]) + + from sklearn.manifold import TSNE + proj = TSNE(n_components=2) + reduced = proj.fit_transform(normed) + + row = 0 + for env in env_names: + ''' + for i, hyper in enumerate(HYPERS): + sz = len(data[env][hyper]) + data[env][hyper] = normed[row:row+sz, i].tolist() + ''' + sz = len(data[env]['agent_steps']) + + data[env] = {k: v for k, v in data[env].items() if k in ALL_KEYS} + data[env]['tsne1'] = reduced[row:row+sz, 0].tolist() + data[env]['tsne2'] = reduced[row:row+sz, 1].tolist() + row += sz + + json.dump(data, open('all_cache.json', 'w')) + +if __name__ == '__main__': + compute_tsne() diff --git a/compile_puffer.py b/compile_puffer.py new file mode 100644 index 000000000..73d6a665c --- /dev/null +++ b/compile_puffer.py @@ -0,0 +1,196 @@ +import torch +from torch import nn +from torch.utils.benchmark import Timer +from torch.utils.flop_counter import FlopCounterMode +from torch import func + +from torch.backends import cudnn +cudnn.benchmark = True +cudnn.deterministic = False +cudnn.benchmark_limit = 32 + +torch.set_float32_matmul_precision('high') +torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + +class Default(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.hidden_size = hidden_size + self.encoder = torch.nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.GELU(), + ) + self.decoder = nn.Linear(hidden_size, output_size) + self.value = nn.Linear(hidden_size, 1) + + def forward(self, observations): + hidden = self.encode_observations(observations) + logits, values = self.decode_actions(hidden) + return logits, values + + def encode_observations(self, observations, state=None): + batch_size = observations.shape[0] + observations = observations.view(batch_size, -1) + return self.encoder(observations) + + def decode_actions(self, hidden): + logits = self.decoder(hidden) + values = self.value(hidden) + return logits, values + + +class LSTMWrapper(nn.Module): + def __init__(self, policy, input_size, hidden_size, output_size): + super().__init__() + self.policy = policy + input_size = hidden_size + + self.input_size = input_size + self.hidden_size = hidden_size + + self.cell = torch.nn.LSTMCell(input_size, hidden_size) + + def forward(self, observations, h, c): + hidden = self.policy.encode_observations(observations) + hidden, c = self.cell(hidden, (h, c)) + logits, values = self.policy.decode_actions(hidden) + return logits, values, hidden, c + +def get_params_and_buffers(model): + buffers = dict(model.named_buffers()) + 
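+    # func.functional_call expects a single name -> tensor mapping, so buffers and
+    # trainable parameters are merged into one dict below; parameters with
+    # requires_grad=False are skipped.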
param_names = [k for k, v in model.named_parameters() if v.requires_grad] + params = [v for k, v in model.named_parameters() if v.requires_grad] + params_dict = dict(zip(param_names, params)) + return {**buffers, **params_dict} + + +@torch.compile(fullgraph=True, dynamic=False, mode='reduce-overhead') +def functional_forward(model, params_and_buffers, batch, h, c): + return func.functional_call(model, params_and_buffers, (batch, h, c)) + +def rollout(model, params_and_buffers, batch, h, c, seq): + all_logits = [] + all_values = [] + for i in range(seq): + logits, values, h, c = functional_forward(model, params_and_buffers, batch[i], h, c) + all_logits.append(logits) + all_values.append(values) + + logits = torch.stack(all_logits, dim=0) + values = torch.stack(all_values, dim=0) + + return logits, values + +@torch.compile(fullgraph=True, dynamic=False, mode='reduce-overhead') +def fast_rollout(model, batch, h, c, seq): + logits = torch.empty(seq, batch.shape[1], OUTPUT_SIZE, device=batch.device, dtype=batch.dtype) + values = torch.empty(seq, batch.shape[1], 1, device=batch.device, dtype=batch.dtype) + for i in range(seq): + l, v, h, c = model(batch[i], h, c) + logits[i] = l + values[i] = v + + return logits, values + +def evaluate(model, params_and_buffers, batch, h, c, seq): + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + return fast_rollout(model, batch, h, c, seq) + +def compute_loss(params_and_buffers, model, batch, h, c, seq): + logits, values = rollout(model, params_and_buffers, batch, h, c, seq) + loss = -torch.log(torch.softmax(logits, dim=-1)).mean() + (values**2).mean() + return loss + +grad_fn = torch.compile(func.grad(compute_loss), + fullgraph=True, dynamic=False, mode='reduce-overhead') + +#grad_fn = func.grad(compute_loss) + +def train(model, params_and_buffers, batch, h, c, loops, seq): + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for _ in range(loops): + grads = grad_fn(params_and_buffers, model, batch, h, c, seq) + for name in grads: + params_and_buffers[name].sub_(0.01 * grads[name]) + + return params_and_buffers + +if __name__ == '__main__': + INPUT_SIZE = 128 + HIDDEN_SIZE = 128 + OUTPUT_SIZE = 4 + B = 256 + SEQ = 64 + LOOPS = 4 + dtype = torch.bfloat16 + + model = LSTMWrapper( + Default(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE), + INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE + ).cuda() + + # TODO: carefully test slowdown from this + params_and_buffers = get_params_and_buffers(model) + #model = torch.compile(model, mode='reduce-overhead', dynamic=False, fullgraph=True) + + # Create input batch + batch = torch.randn(SEQ, B, INPUT_SIZE).cuda().to(dtype) + + # Define a multi-step function to run multiple forwards in one compiled graph + # Manual FLOPs calculation + I = INPUT_SIZE + H = HIDDEN_SIZE + O = OUTPUT_SIZE + flops = B * (2*I*H + 16*H*H + 2*H*O + 2*H) + + h = torch.zeros(B, HIDDEN_SIZE).cuda().to(dtype) + c = torch.zeros(B, HIDDEN_SIZE).cuda().to(dtype) + + # Warmup + for _ in range(3): + _ = evaluate(model, params_and_buffers, batch, h, c, SEQ) + # Timing + timer = Timer( + stmt='evaluate(model, params_and_buffers, batch, h, c, SEQ)', + globals={ + 'evaluate': evaluate, + 'params_and_buffers': params_and_buffers, + 'model': model, + 'batch': batch, + 'h': h, + 'c': c, + 'SEQ': SEQ, + } + ) + output = timer.timeit(LOOPS) + + cost = output.mean / SEQ # Average time per forward pass (fixed from times[0] to mean) + FLOPS = flops / cost + perf_evaluate = f'FLOPS: {FLOPS / 1e12:.2f}T, SPS: {B/cost/1e6:.2f}M' 
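+    # Rough accounting behind the manual FLOPs estimate above (2 FLOPs per
+    # multiply-accumulate, elementwise activations ignored):
+    #   encoder Linear:  2*I*H per sample
+    #   LSTMCell:        W_ih (4H x H) and W_hh (4H x H) matmuls -> 8*H*H + 8*H*H = 16*H*H
+    #                    (the cell input is hidden_size-dim since LSTMWrapper sets input_size = hidden_size)
+    #   decoder + value: 2*H*O + 2*H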
+ + # Warmup + for _ in range(1): + _ = train(model, params_and_buffers, batch, h, c, LOOPS, SEQ) + + # Timing + timer = Timer( + stmt='train(model, params_and_buffers, batch, h, c, LOOPS, SEQ)', + globals={ + 'train': train, + 'params_and_buffers': params_and_buffers, + 'model': model, + 'batch': batch, + 'h': h, + 'c': c, + 'LOOPS': LOOPS, + 'SEQ': SEQ, + } + ) + + output = timer.timeit(1) + cost = output.mean / SEQ / LOOPS # Average time per forward pass (fixed from times[0] to mean) + FLOPS = 3*flops / cost + perf_train = f'FLOPS: {FLOPS / 1e12:.2f}T, SPS: {B/cost/1e6:.2f}M' + + print(perf_evaluate) + print(perf_train) diff --git a/constellation.py b/constellation.py new file mode 100644 index 000000000..a0bed137f --- /dev/null +++ b/constellation.py @@ -0,0 +1,890 @@ +from dash import Dash, html, dcc +from dash.dependencies import Input, Output +import pandas as pd +import plotly.graph_objects as go +import plotly.express as px +import numpy as np +import json +import glob +import os + +FONT_FAMILY = 'Arial' +FONT_SIZE_TITLE = 28 +FONT_SIZE_AXIS = 22 +FONT_SIZE_TICK = 20 +FONT_SIZE_TICK_3D = 14 +FONT_SIZE_LEGEND = 18 +FONT_COLOR = '#f1f1f1' +PLOT_BG_COLOR = '#061a1a' +PAPER_BG_COLOR = '#061a1a' +LINE_WIDTH = 4 +LINE_COLORS = ["#0000b3", "#0010d9", "#0020ff", "#0040ff", "#0060ff", "#0080ff", "#009fff", "#00bfff", "#00ffff"][::-1] +roygbiv = np.random.permutation(['aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue', 'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgrey', 'darkgreen', 'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue', 'dimgray', 'dimgrey', 'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia', 'gainsboro', 'ghostwhite', 'gold', 'goldenrod', 'gray', 'grey', 'green', 'greenyellow', 'honeydew', 'hotpink', 'indianred', 'indigo', 'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan', 'lightgoldenrodyellow', 'lightgray', 'lightgrey', 'lightgreen', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue', 'lightslategray', 'lightslategrey', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen', 'linen', 'magenta', 'maroon', 'mediumaquamarine', 'mediumblue', 'mediumorchid', 'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen', 'mediumturquoise', 'mediumvioletred', 'midnightblue', 'mintcream', 'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive', 'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip', 'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple', 'red', 'rosybrown', 'royalblue', 'saddlebrown', 'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey', 'snow', 'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise', 'violet', 'wheat', 'white', 'whitesmoke', 'yellow', 'yellowgreen']) +#roygbiv = ['red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet'] +TITLE_FONT = dict( + family=FONT_FAMILY, + size=FONT_SIZE_TITLE, + color=FONT_COLOR +) +AXIS_FONT = dict( + family=FONT_FAMILY, + size=FONT_SIZE_AXIS, + color=FONT_COLOR +) +TICK_FONT = 
dict( + family=FONT_FAMILY, + size=FONT_SIZE_TICK, + color=FONT_COLOR +) +GRID_COLOR = '#00f1f1' +TICK_FONT_3D = dict( + family=FONT_FAMILY, + size=FONT_SIZE_TICK_3D, + color=FONT_COLOR +) +LEGEND_FONT = dict( + family=FONT_FAMILY, + size=FONT_SIZE_LEGEND, + color=FONT_COLOR +) +HYPERS = [ + 'train/learning_rate', + 'train/ent_coef', + 'train/gamma', + 'train/gae_lambda', + 'train/vtrace_rho_clip', + 'train/vtrace_c_clip', + 'train/clip_coef', + 'train/vf_clip_coef', + 'train/vf_coef', + 'train/max_grad_norm', + 'train/adam_beta1', + 'train/adam_beta2', + 'train/adam_eps', + 'train/prio_alpha', + 'train/prio_beta0', + 'train/bptt_horizon', + 'train/num_minibatches', + 'train/minibatch_size', + 'policy/hidden_size', + 'env/num_envs', +] +ALL_KEYS = [ + 'agent_steps', + 'cost', + 'environment/score', + 'environment/perf' +] + HYPERS + +SCATTER_COLOR = ['env_name'] + ALL_KEYS + +import colorsys +import numpy as np + +def rgb_to_hex(rgb): + """Convert RGB tuple to hex string.""" + return '#%02x%02x%02x' % (int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255)) + +def generate_distinct_palette(n): + """ + Generate a palette with n maximally distinct colors across the hue spectrum. + + Parameters: + n (int): Number of colors to generate. + + Returns: + list: List of hex color strings. + """ + if n < 1: + raise ValueError("n must be at least 1") + + # Generate hues evenly spaced across the spectrum (0 to 1) + hues = np.linspace(0, 1, n, endpoint=False) + colors = [] + for hue in hues: + # Use full saturation and value for vivid colors + rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0) + colors.append(rgb) + hex_colors = [rgb_to_hex(color) for color in colors] + return hex_colors + +def pareto_idx(steps, costs, scores): + idxs = [] + for i in range(len(steps)): + better = [scores[j] >= scores[i] and + costs[j] < costs[i] and steps[j] < steps[i] + for j in range(len(scores))] + if not any(better): + idxs.append(i) + + return idxs + +def build_dataset(dataframe): + dataset = [] + for hyper in HYPERS: + dat = dataframe[hyper] + #mmin = dataframe[f'sweep/{hyper}/min'] + #mmax = dataframe[f'sweep/{hyper}/max'] + #distribution = dataframe[f'sweep/{hyper}/distribution'] + + + +def load_sweep_data(path): + data = {} + keys = None + for fpath in glob.glob(path): + with open(fpath, 'r') as f: + exp = json.load(f) + + if not data: + for kk in exp.keys(): + if kk == 'data': + for k, v in exp[kk][-1].items(): + data[k] = [] + else: + data[kk] = [] + + discard = False + for kk in list(data.keys()): + if kk not in exp and kk not in exp['data'][-1]: + discard = True + break + + if discard: + continue + + for kk in list(data.keys()): + if kk in exp: + v = exp[kk] + sweep_key = f'sweep/{kk}/distribution' + if sweep_key in data and exp[sweep_key] == 'logit_normal': + v = 1 - v + elif kk in ('train/vtrace_rho_clip', 'train/vtrace_c_clip'): + v = max(v, 0.1) + + data[kk].append(v) + else: + data[kk].append(exp['data'][-1][kk]) + + return data + +def cached_sweep_load(path, env_name): + cache_file = os.path.join(path, 'cache.json') + if not os.path.exists(cache_file): + data = load_sweep_data(os.path.join(path, '*.json')) + with open(cache_file, 'w') as f: + json.dump(data, f) + + with open(cache_file, 'r') as f: + data = json.load(f) + + steps = data['agent_steps'] + costs = data['cost'] + scores = data['environment/score'] + + idxs = pareto_idx(steps, costs, scores) + + # Create a DataFrame for this environment + df_data = {} + for k in data: + df_data[k] = [data[k][i] for i in idxs] + + # Apply performance cap + 
df_data['environment/perf'] = [min(e, 1.0) for e in df_data['environment/perf']] + + # Adjust steps by frameskip if present + if 'env/frameskip' in df_data: + skip = df_data['env/frameskip'] + df_data['agent_steps'] = [n*m for n, m in zip(df_data['agent_steps'], skip)] + + # Add environment name + df_data['env_name'] = [env_name] * len(idxs) + + return pd.DataFrame(df_data) + +def compute_tsne(): + dataset = EXPERIMENTS[HYPERS].copy() # Create a copy to avoid modifying the original + + # Normalize each hyperparameter column using its corresponding min and max columns + for hyper in HYPERS: + min_col = f'sweep/{hyper}/min' + max_col = f'sweep/{hyper}/max' + + mmin = min(EXPERIMENTS[min_col]) + mmax = max(EXPERIMENTS[max_col]) + + distribution = EXPERIMENTS[f'sweep/{hyper}/distribution'] + if 'log' in distribution or 'pow2' in distribution: + mmin = np.log(mmin) + mmax = np.log(mmax) + normed = np.log(dataset[hyper]) + else: + normed = dataset[hyper] + + dataset[hyper] = (normed - mmin) / (mmax - mmin) + # Normalize: (value - min) / (max - min) for each row + + #dataset[hyper] = (dataset[hyper] - EXPERIMENTS[min_col]) / (EXPERIMENTS[max_col] - EXPERIMENTS[min_col]) + + # Filter dataset based on performance threshold + # Apply TSNE + from sklearn.manifold import TSNE + proj = TSNE(n_components=2) + reduced = proj.fit_transform(dataset) + EXPERIMENTS['tsne1'] = reduced[:, 0] + EXPERIMENTS['tsne2'] = reduced[:, 1] + +env_names = ['tripletriad', 'grid', 'moba', 'tower_climb', 'tetris', 'breakout', 'pong', 'g2048', 'snake', 'pacman'] +env_all = ['all'] + env_names +#env_names = ['grid', 'breakout', 'g2048'] +#env_names = ['grid'] + +roygbiv = generate_distinct_palette(len(env_names)) + +# Create a list of DataFrames for each environment +dfs = [cached_sweep_load(f'experiments/logs/puffer_{name}', name) for name in env_names] + +# Concatenate all DataFrames into a single DataFrame +EXPERIMENTS = pd.concat(dfs, ignore_index=True) +#EXPERIMENTS.set_index('env_name', inplace=True) +compute_tsne() + +app = Dash() +app.css.append_css({'external_stylesheets': 'dash.css'}) +app.layout = html.Div([ + html.H1('Puffer Constellation', style={'textAlign': 'center'}), + html.Br(), + + html.Label([ + "X: ", + dcc.Dropdown( + id="optimal-dropdown-x", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="cost", + style={"width": "50%"} + ) + ]), + html.Label([ + "Y: ", + dcc.Dropdown( + id="optimal-dropdown-y", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="agent_steps", + style={"width": "50%"} + ) + ]), + html.Label([ + "Z: ", + dcc.Dropdown( + id="optimal-dropdown-z", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="environment/perf", + style={"width": "50%"} + ) + ]), + dcc.Graph(id='optimal'), + html.Br(), + + html.Label([ + "Environment: ", + dcc.Dropdown( + id="scatter-dropdown-env", + options=[{"label": key, "value": key} for key in env_all], + value="all", + style={"width": "50%"} + ) + ]), + html.Br(), + html.Label([ + "X: ", + dcc.Dropdown( + id="scatter-dropdown-x", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="train/learning_rate", + style={"width": "50%"} + ), + dcc.Checklist( + id="scatter-checkbox-logx", + options=[{"label": "Log", "value": "log"}], + value=["log"], + style={"display": "inline-block", "margin-left": "10px"} + ), + ]), + html.Br(), + html.Label([ + "Y: ", + dcc.Dropdown( + id="scatter-dropdown-y", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="environment/perf", + 
style={"width": "50%"} + ), + dcc.Checklist( + id="scatter-checkbox-logy", + options=[{"label": "Log", "value": "log"}], + value=[], + style={"display": "inline-block", "margin-left": "10px"} + ), + + ]), + html.Br(), + html.Label([ + "Color: ", + dcc.Dropdown( + id="scatter-dropdown-color", + options=[{"label": key, "value": key} for key in SCATTER_COLOR], + value="env_name", + style={"width": "50%"} + ) + ]), + html.Br(), + html.Label([ + "Range 1: ", + dcc.Dropdown( + id="scatter-dropdown-range-1", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="agent_steps", + style={"width": "50%"} + ), + dcc.RangeSlider( + id='scatter-range-1', + min=0.0, + max=1.0, + step=0.05, + value=[0.0, 0.25] + ), + ]), + html.Br(), + html.Label([ + "Range 2: ", + dcc.Dropdown( + id="scatter-dropdown-range-2", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="cost", + style={"width": "50%"} + ), + dcc.RangeSlider( + id='scatter-range-2', + min=0.0, + max=1.0, + step=0.05, + value=[0.0, 0.95] + ), + ]), + dcc.Graph(id='scatter'), + html.Br(), + + #html.Label([ + # "X Axis: ", + # dcc.Dropdown( + # id="hyper-box-x", + # options=[{"label": key, "value": key} for key in ['cost', 'agent_steps']], + # value="agent_steps", + # style={"width": "50%"} + # ) + #]), + #dcc.Graph(id='hyper-box'), + + html.Br(), + html.Label([ + "Range 1: ", + dcc.Dropdown( + id="hyper-dropdown-range-1", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="environment/perf", + style={"width": "50%"} + ), + dcc.RangeSlider( + id='hyper-range-1', + min=0.0, + max=1.0, + step=0.05, + value=[0.8, 1.0] + ), + ]), + html.Br(), + html.Label([ + "Range 2: ", + dcc.Dropdown( + id="hyper-dropdown-range-2", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="agent_steps", + style={"width": "50%"} + ), + dcc.RangeSlider( + id='hyper-range-2', + min=0.0, + max=1.0, + step=0.05, + value=[0.0, 1.0] + ), + ]), + dcc.Graph(id='hyper'), + + + html.Br(), + html.Label([ + "Range 1: ", + dcc.Dropdown( + id="tsnee-dropdown-range-1", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="environment/perf", + style={"width": "50%"} + ), + dcc.RangeSlider( + id='tsnee-range-1', + min=0.0, + max=1.0, + step=0.05, + value=[0.5, 1.0] + ), + ]), + html.Br(), + html.Label([ + "Range 2: ", + dcc.Dropdown( + id="tsnee-dropdown-range-2", + options=[{"label": key, "value": key} for key in ALL_KEYS], + value="cost", + style={"width": "50%"} + ), + dcc.RangeSlider( + id='tsnee-range-2', + min=0.0, + max=1.0, + step=0.05, + value=[0.0, 1.0] + ), + ]), + dcc.Graph(id='tsnee'), + +], +style={"width": 1280} +) + +import plotly.express as px +import plotly.graph_objects as go +import numpy as np +from scipy.spatial.distance import cdist + +# Assuming EXPERIMENTS is your pandas DataFrame, and xkey, ykey, zkey are defined. 
+# Also assuming percentages for cutoffs, e.g.: +percentage1 = 5.0 # Percentage for XYZ distance threshold relative to plot diagonal in transformed space +percentage2 = 0.5 # Percentage for PCA distance threshold relative to PCA diagonal + +@app.callback( + Output("optimal", "figure"), + Input("optimal-dropdown-x", "value"), + Input("optimal-dropdown-y", "value"), + Input("optimal-dropdown-z", "value") +) +def update_optimal_plot(xkey, ykey, zkey): + all_x = EXPERIMENTS[xkey].values + all_y = EXPERIMENTS[ykey].values + all_z = EXPERIMENTS[zkey].values + all_pca1 = EXPERIMENTS['tsne1'].values + all_pca2 = EXPERIMENTS['tsne2'].values + all_env = EXPERIMENTS['env_name'].values# Handle transformed coordinates for XYZ (accounting for log axes) + trans_x = np.log10(all_x) # Assuming all_x > 0 + trans_y = np.log10(all_y) # Assuming all_y > 0 + trans_z = all_z + points_trans_xyz = np.column_stack((trans_x, trans_y, trans_z)) + + # Compute ranges in transformed space + range_tx = np.max(trans_x) - np.min(trans_x) + range_ty = np.max(trans_y) - np.min(trans_y) + range_tz = np.max(trans_z) - np.min(trans_z) + diagonal_xyz = np.sqrt(range_tx**2 + range_ty**2 + range_tz**2) + delta1 = (percentage1 / 100.0) * diagonal_xyz + + # For PCA (assuming linear scales) + points_pca = np.column_stack((all_pca1, all_pca2)) + range_p1 = np.max(all_pca1) - np.min(all_pca1) + range_p2 = np.max(all_pca2) - np.min(all_pca2) + diagonal_pca = np.sqrt(range_p1**2 + range_p2**2) + delta2 = (percentage2 / 100.0) * diagonal_pca + + # Create the base scatter plot + f = px.scatter_3d( + x=all_x, + y=all_y, + z=all_z, + color=all_env, + log_x=True, + log_y=True, + log_z=False, + color_discrete_sequence=roygbiv + ) + + # Compute pairwise L2 distances in transformed spaces + dists_xyz = cdist(points_trans_xyz, points_trans_xyz) + dists_pca = cdist(points_pca, points_pca) + + # Create boolean masks + xyz_mask = dists_xyz < delta1 + pca_mask = dists_pca < delta2 + # Use boolean array for upper triangle to avoid type mismatch + triu_mask = np.triu(np.ones_like(dists_xyz, dtype=bool), k=1) + + # Combine masks with boolean operations + mask = xyz_mask & pca_mask & triu_mask + + # Get indices of valid pairs + i, j = np.where(mask) + + # Collect line segment coordinates (in original space) + line_x = [] + line_y = [] + line_z = [] + for k in range(len(i)): + line_x.extend([all_x[i[k]], all_x[j[k]], None]) + line_y.extend([all_y[i[k]], all_y[j[k]], None]) + line_z.extend([all_z[i[k]], all_z[j[k]], None]) + + # Add the lines as a single trace + if line_x: + f.add_trace( + go.Scatter3d( + x=line_x, + y=line_y, + z=line_z, + mode='lines', + line=dict(color='rgba(255,255,255,0.25)', width=2), + showlegend=False + ) + ) + + # Show the figure + f.show() + + layout_dict = { + 'title': dict(text='Pareto', font=TITLE_FONT), + 'showlegend': True, + 'legend': dict(font=LEGEND_FONT), + 'plot_bgcolor': PLOT_BG_COLOR, + 'paper_bgcolor': PAPER_BG_COLOR, + 'width': 1280, + 'height': 720, + 'autosize': False, + 'scene': dict( + xaxis=dict( + title=dict(text=xkey, font=AXIS_FONT), + tickfont=TICK_FONT_3D, + type='log', + showgrid=True, + gridcolor=GRID_COLOR, + backgroundcolor=PLOT_BG_COLOR, + zeroline=False + ), + yaxis=dict( + title=dict(text=ykey, font=AXIS_FONT), + tickfont=TICK_FONT_3D, + type='log', + showgrid=True, + gridcolor=GRID_COLOR, + backgroundcolor=PLOT_BG_COLOR, + zeroline=False + ), + zaxis=dict( + title=dict(text=zkey, font=AXIS_FONT), + tickfont=TICK_FONT_3D, + type='linear', + showgrid=True, + gridcolor=GRID_COLOR, + 
backgroundcolor=PLOT_BG_COLOR, + zeroline=False + ), + bgcolor=PLOT_BG_COLOR, + ) + } + f.update_layout(**layout_dict) + return f + + +@app.callback( + Output("scatter", "figure"), + Input("scatter-dropdown-env", "value"), + Input("scatter-dropdown-x", "value"), + Input("scatter-checkbox-logx", "value"), + Input("scatter-dropdown-y", "value"), + Input("scatter-checkbox-logy", "value"), + Input("scatter-dropdown-color", "value"), + Input("scatter-dropdown-range-1", "value"), + Input("scatter-range-1", "value"), + Input("scatter-dropdown-range-2", "value"), + Input("scatter-range-2", "value"), +) +def update_scatter(env, xkey, logx, ykey, logy, zkey, range1_key, range1, range2_key, range2): + #env_data = EXPERIMENTS.loc[env] + if env == 'all': + env_data = EXPERIMENTS + else: + env_data = EXPERIMENTS[EXPERIMENTS['env_name'] == env] + + range1_mmin = min(EXPERIMENTS[range1_key]) + range1_mmax = max(EXPERIMENTS[range1_key]) + norm_range1 = (EXPERIMENTS[range1_key] - range1_mmin) / (range1_mmax - range1_mmin) + + range2_mmin = min(EXPERIMENTS[range2_key]) + range2_mmax = max(EXPERIMENTS[range2_key]) + norm_range2 = (EXPERIMENTS[range2_key] - range2_mmin) / (range2_mmax - range2_mmin) + + mask = (norm_range1 >= range1[0]) & (norm_range1 <= range1[1]) & (norm_range2 >= range2[0]) & (norm_range2 <= range2[1]) + + env_data = env_data[mask] + + x = env_data[xkey] + y = env_data[ykey] + z = env_data[zkey] + + if zkey == 'env_name': + f = px.scatter(x=x, y=y, color=z, color_discrete_sequence=roygbiv) + else: + mmin = min(z) + mmax = max(z) + thresh = np.geomspace(mmin, mmax, 8) + all_fx = [] + all_fy = [] + bin_label = [] + for j in range(7): + filter = (thresh[j] < z) & (z < thresh[j+1]) + if filter.sum() <= 2: + continue + + fx = x[filter] + fy = y[filter] + all_fx += fx.tolist() + all_fy += fy.tolist() + bin_label += [str(thresh[j])] * len(fx) + + f = px.scatter(x=all_fx, y=all_fy, color=bin_label, color_discrete_sequence=roygbiv) + + f.update_traces(marker_size=10) + layout_dict = { + 'title': dict(text='Experiments', font=TITLE_FONT), + 'showlegend': True, + 'legend': dict(font=LEGEND_FONT), + 'plot_bgcolor': PLOT_BG_COLOR, + 'paper_bgcolor': PAPER_BG_COLOR, + 'width': 1280, + 'height': 720, + 'autosize': False, + 'xaxis': dict( + title=dict(text=xkey, font=AXIS_FONT), + tickfont=TICK_FONT, + showgrid=False, + type='log' if 'log' in logx else 'linear', + ), + 'yaxis': dict( + title=dict(text=ykey, font=AXIS_FONT), + tickfont=TICK_FONT, + showgrid=False, + type='log' if 'log' in logy else 'linear', + ) + } + f.update_layout(**layout_dict) + return f + +@app.callback( + Output("hyper-box", "figure"), + Input("hyper-box-x", "value") +) +def update_hyper_box(x): + buckets = 4 + env_data = {} + for env in env_names: + #data = EXPERIMENTS.loc[env] + data = EXPERIMENTS[EXPERIMENTS['env_name'] == env] + steps = data['agent_steps'] + costs = data['cost'] + scores = data['environment/score'] + x_data = costs if x == 'cost' else steps + hyper_data = {} + env_data[env] = {'x': x_data, 'hypers': hyper_data} + for h in HYPERS: + hyper_data[h] = data[h] + all_x = [x for env in env_data for x in env_data[env]['x']] + x_min, x_max = min(all_x), max(all_x) + bucket_edges = np.linspace(x_min, x_max, buckets + 1) + bucket_centers = (bucket_edges[:-1] + bucket_edges[1:]) / 2 + heatmap_data = np.zeros((len(HYPERS), buckets)) + for i, hyper in enumerate(HYPERS): + for j in range(buckets): + bucket_means = [] + for env in env_data: + if hyper not in env_data[env]['hypers']: + continue + x_vals = 
np.array(env_data[env]['x']) + hyper_vals = np.array(env_data[env]['hypers'][hyper]) + idxs = (x_vals >= bucket_edges[j]) & (x_vals < bucket_edges[j+1]) + if np.any(idxs): + bucket_means.append(np.mean(hyper_vals[idxs])) + heatmap_data[i, j] = np.mean(bucket_means) if bucket_means else np.nan + heatmap_data = np.log(heatmap_data) + heatmap_data -= heatmap_data[:, 0, None] # Normalize + f = px.imshow(heatmap_data, x=bucket_centers, y=HYPERS, color_continuous_scale='Viridis', zmin=np.nanmin(heatmap_data), zmax=np.nanmax(heatmap_data), labels=dict(color="Value")) + layout_dict = { + 'title': dict(text="Hyperparameter Drift", font=TITLE_FONT), + 'showlegend': True, + 'legend': dict(font=LEGEND_FONT), + 'plot_bgcolor': PLOT_BG_COLOR, + 'paper_bgcolor': PAPER_BG_COLOR, + 'width': 1280, + 'height': 720, + 'autosize': False, + 'xaxis': dict( + title=dict(text=x.capitalize(), font=AXIS_FONT), + tickfont=TICK_FONT, + showgrid=False + ), + 'yaxis': dict( + title=dict(text="Hyperparameters", font=AXIS_FONT), + tickfont=TICK_FONT, + showgrid=False + ) + } + f.update_layout(**layout_dict) + return f + +@app.callback( + Output("hyper", "figure"), + Input("hyper-dropdown-range-1", "value"), + Input("hyper-range-1", "value"), + Input("hyper-dropdown-range-2", "value"), + Input("hyper-range-2", "value"), +) +def update_hyper_plot(xkey, range1, ykey, range2): + # Initialize figure + f = go.Figure() + f.update_layout( + title=dict(text='Hyperparameter Stable Range', font=TITLE_FONT), + xaxis=dict(title=dict(text='Value', font=AXIS_FONT), tickfont=TICK_FONT), + yaxis=dict(title=dict(text='Hyper', font=AXIS_FONT), tickfont=TICK_FONT), + showlegend=True, + legend=dict(font=LEGEND_FONT), + plot_bgcolor=PLOT_BG_COLOR, + paper_bgcolor=PAPER_BG_COLOR, + width=1280, + height=720, + autosize=False, + xaxis_type='log', + barmode='overlay', # Overlay bars instead of stacking + ) + f.update_xaxes(showgrid=False) + f.update_yaxes(showgrid=False) + + range1_mmin = min(EXPERIMENTS[xkey]) + range1_mmax = max(EXPERIMENTS[xkey]) + norm_x = (EXPERIMENTS[xkey] - range1_mmin) / (range1_mmax - range1_mmin) + range2_mmin = min(EXPERIMENTS[ykey]) + range2_mmax = max(EXPERIMENTS[ykey]) + norm_y = (EXPERIMENTS[ykey] - range2_mmin) / (range2_mmax - range2_mmin) + mask = (norm_x >= range1[0]) & (norm_x <= range1[1]) & (norm_y >= range2[0]) & (norm_y <= range2[1]) + filtered = EXPERIMENTS[mask] + + for i, env in enumerate(env_names): + #env_data = EXPERIMENTS.loc[env] + env_data = filtered[filtered['env_name'] == env] + if len(env_data) < 2: + continue + + steps = env_data['agent_steps'] + costs = env_data['cost'] + scores = env_data['environment/score'] + + max_score = max(scores) + max_steps = max(steps) + n = len(scores) + + + for k, hyper in enumerate(HYPERS): + y = env_data[hyper] + + ymin = min(y) + ymax = max(y) + f.add_trace( + go.Bar( + x=[ymax - ymin], + y=[hyper], # Hyperparameter as x-axis + base=ymin, + showlegend=False, + marker_color='#00f1f1', + opacity=0.25, + width=1.0, + orientation='h' + ) + ) + + return f + + +@app.callback( + Output("tsnee", "figure"), + Input("tsnee-dropdown-range-1", "value"), + Input("tsnee-range-1", "value"), + Input("tsnee-dropdown-range-2", "value"), + Input("tsnee-range-2", "value"), +) +def update_pca_plot(xkey, range1, ykey, range2): + # Initialize figure + f = go.Figure() + f.update_layout( + title=dict(text='Hyperparameter Stable Range', font=TITLE_FONT), + xaxis=dict(title=dict(text='Value', font=AXIS_FONT), tickfont=TICK_FONT), + yaxis=dict(title=dict(text='Hyper', font=AXIS_FONT), 
tickfont=TICK_FONT), + showlegend=True, + legend=dict(font=LEGEND_FONT), + plot_bgcolor=PLOT_BG_COLOR, + paper_bgcolor=PAPER_BG_COLOR, + width=1280, + height=720, + autosize=False, + xaxis_type='log', + barmode='overlay', # Overlay bars instead of stacking + ) + f.update_xaxes(showgrid=False) + f.update_yaxes(showgrid=False) + + range1_mmin = min(EXPERIMENTS[xkey]) + range1_mmax = max(EXPERIMENTS[xkey]) + norm_x = (EXPERIMENTS[xkey] - range1_mmin) / (range1_mmax - range1_mmin) + range2_mmin = min(EXPERIMENTS[ykey]) + range2_mmax = max(EXPERIMENTS[ykey]) + norm_y = (EXPERIMENTS[ykey] - range2_mmin) / (range2_mmax - range2_mmin) + mask = (norm_x >= range1[0]) & (norm_x <= range1[1]) & (norm_y >= range2[0]) & (norm_y <= range2[1]) + filtered = EXPERIMENTS[mask] + + f = px.scatter( + x=filtered['tsne1'], + y=filtered['tsne2'], + color=filtered['env_name'], + color_discrete_sequence=roygbiv + ) + + f.update_traces(marker_size=10) + layout_dict = { + 'title': dict(text='Experiments', font=TITLE_FONT), + 'showlegend': True, + 'legend': dict(font=LEGEND_FONT), + 'plot_bgcolor': PLOT_BG_COLOR, + 'paper_bgcolor': PAPER_BG_COLOR, + 'width': 1280, + 'height': 720, + 'autosize': False, + 'xaxis': dict( + title=dict(text='TSNE-1', font=AXIS_FONT), + tickfont=TICK_FONT, + showgrid=False + ), + 'yaxis': dict( + title=dict(text='TSNE-2', font=AXIS_FONT), + tickfont=TICK_FONT, + showgrid=False + ) + } + f.update_layout(**layout_dict) + return f + + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=8000) diff --git a/profile_jax.py b/profile_jax.py new file mode 100644 index 000000000..8d5d51a43 --- /dev/null +++ b/profile_jax.py @@ -0,0 +1,67 @@ +import jax +import jax.numpy as jnp +from jax import jit, random, lax +import timeit + +INPUT_SIZE = 16 +HIDDEN_SIZE = 128 +OUTPUT_SIZE = 16 +B = 2048 +dtype = jnp.bfloat16 +inner_loops = 100 # Number of inner iterations to amortize overhead + +def init_params(key): + keys = random.split(key, 3) + # Use uniform initialization to match PyTorch's Kaiming uniform for ReLU + bound1 = jnp.sqrt(6 / INPUT_SIZE) + w1 = random.uniform(keys[0], shape=(INPUT_SIZE, HIDDEN_SIZE), minval=-bound1, maxval=bound1, dtype=dtype) + b1 = jnp.zeros(HIDDEN_SIZE, dtype=dtype) + bound2 = jnp.sqrt(6 / HIDDEN_SIZE) + w2 = random.uniform(keys[1], shape=(HIDDEN_SIZE, HIDDEN_SIZE), minval=-bound2, maxval=bound2, dtype=dtype) + b2 = jnp.zeros(HIDDEN_SIZE, dtype=dtype) + bound3 = jnp.sqrt(6 / HIDDEN_SIZE) + w3 = random.uniform(keys[2], shape=(HIDDEN_SIZE, OUTPUT_SIZE), minval=-bound3, maxval=bound3, dtype=dtype) + b3 = jnp.zeros(OUTPUT_SIZE, dtype=dtype) + return {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2, 'w3': w3, 'b3': b3} + +def model(params, x): + precision = lax.Precision.HIGH # Use HIGH precision for 4090 to leverage Tensor Cores + h = jnp.maximum(jnp.dot(x, params['w1'], precision=precision) + params['b1'], 0) + h = jnp.maximum(jnp.dot(h, params['w2'], precision=precision) + params['b2'], 0) + return jnp.dot(h, params['w3'], precision=precision) + params['b3'] + +# Manual FLOPs calculation (ignores bias adds and ReLUs as negligible) +flops_per_forward = ( + 2 * B * INPUT_SIZE * HIDDEN_SIZE + # First matmul + 2 * B * HIDDEN_SIZE * HIDDEN_SIZE + # Second matmul + 2 * B * HIDDEN_SIZE * OUTPUT_SIZE # Third matmul +) + +# Create concrete inputs +key = random.key(0) +params = init_params(key) +batch = random.normal(random.key(1), (B, INPUT_SIZE), dtype=dtype) + +# Define a jitted multi-step function with lax.scan for better optimization +@jit +def multi_step(params, batch): + def 
body_fun(carry, _): + y = model(params, batch) + carry += y.sum() # Forces computation without noise + return carry, None + carry, _ = lax.scan(body_fun, jnp.array(0.0, dtype=jnp.float32), None, length=inner_loops) + return carry + +# Warmup +for _ in range(10): + _ = multi_step(params, batch).block_until_ready() + +# Timing +def run(): + return multi_step(params, batch).block_until_ready() + +t = timeit.timeit(run, number=10) +cost = t / 10 / inner_loops # Average time per forward pass + +FLOPS = flops_per_forward / cost +print(f'TFLOPS: {FLOPS / 1e12:.2f}') diff --git a/profile_kernels.cu b/profile_kernels.cu new file mode 100644 index 000000000..19579197d --- /dev/null +++ b/profile_kernels.cu @@ -0,0 +1,1035 @@ +// profile_kernels.cu +// Minimal standalone profiler for CUDA kernels +// +// Without torch: nvcc -O3 -arch=sm_80 profile_kernels.cu -o profile_kernels -I. +// With torch: Build with cmake/pytorch and -DUSE_TORCH +// +// Run: ./profile_kernels + +#include +#include +#include + +#ifdef USE_TORCH +#include +#include +#include +#include "pufferlib/extensions/cuda/modules.cu" +#else +#include "pufferlib/extensions/cuda/kernels.cu" +#endif + +const int WARMUP_ITERS = 1000; +const int TIMING_ITERS = 10000; + +const int BR = 4096; // Rollout batch (no T dim) +const int BT = 512; // Train batch (with T dim) +const int T = 64; +const int H = 128; +const int A = 4; + +typedef void (*kernel_fn)(void*); + +void print_timing(const char* name, float ms, int N) { + printf(" %-18s %6.1f us %6.2f M elem/s\n", name, ms * 1000, N / ms / 1e3); +} + +void warmup_gpu() { + // Warm up GPU clocks with some busy work + float* dummy; + cudaMalloc(&dummy, 64 * 1024 * 1024); // 64MB + for (int i = 0; i < 100; i++) { + cudaMemset(dummy, 0, 64 * 1024 * 1024); + } + cudaDeviceSynchronize(); + cudaFree(dummy); +} + +float profile_kernel(kernel_fn fn, void* args) { + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + for (int i = 0; i < WARMUP_ITERS; ++i) { + fn(args); + cudaDeviceSynchronize(); + } + + cudaEventRecord(start); + for (int i = 0; i < TIMING_ITERS; ++i) { + fn(args); + } + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + float ms = 0; + cudaEventElapsedTime(&ms, start, stop); + cudaEventDestroy(start); + cudaEventDestroy(stop); + + cudaDeviceSynchronize(); + c10::cuda::CUDACachingAllocator::emptyCache(); + return ms / TIMING_ITERS; +} + +#ifdef USE_TORCH +float profile_graph(kernel_fn fn, void* args) { + cudaDeviceSynchronize(); + + at::cuda::CUDAGraph cuda_graph; + at::cuda::CUDAStream current_stream = at::cuda::getCurrentCUDAStream(); + + at::cuda::CUDAStream warmup_stream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(warmup_stream); + for (int i = 0; i < WARMUP_ITERS; ++i) { + fn(args); + } + warmup_stream.synchronize(); + + at::cuda::CUDAStream cap_stream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(cap_stream); + cuda_graph.capture_begin(); + fn(args); + cuda_graph.capture_end(); + cap_stream.synchronize(); + + cudaDeviceSynchronize(); + at::cuda::setCurrentCUDAStream(current_stream); + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + for (int i = 0; i < TIMING_ITERS; ++i) { + cuda_graph.replay(); + } + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + float ms = 0; + cudaEventElapsedTime(&ms, start, stop); + cudaEventDestroy(start); + cudaEventDestroy(stop); + + return ms / TIMING_ITERS; +} +#endif + +float rand1() { + return (float)rand() / 
RAND_MAX * 2.0f - 1.0f; +} + +// Fused mingru_gate for inference: takes combined (B, 1, 3*H) = [hidden, gate, proj] +// Outputs: out = sigmoid(proj) * mingru_out, next_state = mingru_out (for recurrence) +typedef struct { + float* state; // (B, 1, H) - input state + float* combined; // (B, 1, 3*H) = [hidden, gate, proj] + float* out; // (B, 1, H) - sigmoid(proj) * mingru_out + float* next_state; // (B, 1, H) - raw mingru_out + int B; + int H; +} MingruGateArgs; + +MingruGateArgs* create_mingrugateargs(int batch, int hidden) { + MingruGateArgs* args = (MingruGateArgs*)calloc(1, sizeof(MingruGateArgs)); + args->B = batch; + args->H = hidden; + + int N_state = batch * hidden; + int N_combined = batch * 3 * hidden; + + cudaMalloc(&args->state, N_state * sizeof(float)); + cudaMalloc(&args->combined, N_combined * sizeof(float)); + cudaMalloc(&args->out, N_state * sizeof(float)); + cudaMalloc(&args->next_state, N_state * sizeof(float)); + + float* state_buf = (float*)malloc(N_state * sizeof(float)); + float* combined_buf = (float*)malloc(N_combined * sizeof(float)); + + // Initialize state with positive values + for (int i = 0; i < N_state; ++i) { + state_buf[i] = fabsf(rand1()) + 0.1f; + } + // Initialize combined = [hidden, gate, proj] + for (int b = 0; b < batch; ++b) { + int base = b * 3 * hidden; + for (int h = 0; h < hidden; ++h) { + combined_buf[base + h] = rand1() * 5.0f; // hidden + combined_buf[base + hidden + h] = rand1() * 5.0f; // gate + combined_buf[base + 2 * hidden + h] = rand1() * 2.0f; // proj + } + } + + cudaMemcpy(args->state, state_buf, N_state * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->combined, combined_buf, N_combined * sizeof(float), cudaMemcpyHostToDevice); + + free(state_buf); + free(combined_buf); + return args; +} + +void free_mingrugateargs(MingruGateArgs* args) { + cudaFree(args->state); + cudaFree(args->combined); + cudaFree(args->out); + cudaFree(args->next_state); + free(args); +} + +void run_mingrugate_forward(MingruGateArgs* args) { + launch_mingru_gate_inference( + args->out, args->next_state, args->combined, args->state, + args->H, args->B, 0); +} + +#ifdef USE_TORCH + +typedef struct { + torch::Tensor state; // (B, 1, H) + torch::Tensor combined; // (B, 1, 3*H) + int B; + int H; +} MingruGateArgsTorch; + +MingruGateArgsTorch* create_mingrugateargs_torch(MingruGateArgs* raw) { + MingruGateArgsTorch* args = new MingruGateArgsTorch(); + args->B = raw->B; + args->H = raw->H; + + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + args->state = torch::from_blob(raw->state, {raw->B, 1, raw->H}, opts); + args->combined = torch::from_blob(raw->combined, {raw->B, 1, 3 * raw->H}, opts); + + return args; +} + +void run_mingrugate_forward_torch(MingruGateArgsTorch* args) { + torch::NoGradGuard no_grad; + mingru_gate(args->state, args->combined); +} + +void run_mingrugate_forward_cpp(MingruGateArgsTorch* args) { + torch::NoGradGuard no_grad; + mingru_gate_cpp(args->state, args->combined); +} + +#endif + +void profile_mingrugate(int batch, int hidden) { + MingruGateArgs* args = create_mingrugateargs(batch, hidden); + + printf("mingru_gate (B=%d, H=%d, combined=%dx%d)\n", batch, hidden, batch, 3*hidden); + + float fwd_ms = profile_kernel((kernel_fn)run_mingrugate_forward, args); + print_timing("\tforward", fwd_ms, batch); + +#ifdef USE_TORCH + MingruGateArgsTorch* args_torch = create_mingrugateargs_torch(args); + + float fwd_torch_ms = profile_kernel((kernel_fn)run_mingrugate_forward_torch, args_torch); + 
print_timing("\tforward (torch)", fwd_torch_ms, batch); + + float fwd_cpp_ms = profile_kernel((kernel_fn)run_mingrugate_forward_cpp, args_torch); + print_timing("\tforward (cpp)", fwd_cpp_ms, batch); + + float fwd_graph_ms = profile_graph((kernel_fn)run_mingrugate_forward_cpp, args_torch); + print_timing("\tforward (graph)", fwd_graph_ms, batch); + + delete args_torch; +#endif + printf("\n"); + + free_mingrugateargs(args); +} + +typedef struct { + float* gate; + float* hidden; + float* log_coeffs; + float* log_values; + float* grad_log_coeffs; + float* grad_log_values; + float* grad_gate; + float* grad_hidden; + int N; +} LogCoeffsAndValuesArgs; + +LogCoeffsAndValuesArgs* create_logcoeffsandvaluesargs(int batch, int seq, int hidden) { + LogCoeffsAndValuesArgs* args = (LogCoeffsAndValuesArgs*)calloc(1, sizeof(LogCoeffsAndValuesArgs)); + args->N = batch*seq * hidden; + + cudaMalloc(&args->gate, args->N * sizeof(float)); + cudaMalloc(&args->hidden, args->N * sizeof(float)); + cudaMalloc(&args->log_coeffs, args->N * sizeof(float)); + cudaMalloc(&args->log_values, args->N * sizeof(float)); + cudaMalloc(&args->grad_gate, args->N * sizeof(float)); + cudaMalloc(&args->grad_hidden, args->N * sizeof(float)); + cudaMalloc(&args->grad_log_coeffs, args->N * sizeof(float)); + cudaMalloc(&args->grad_log_values, args->N * sizeof(float)); + + float* gate_buf = (float*)malloc(args->N * sizeof(float)); + float* hidden_buf = (float*)malloc(args->N * sizeof(float)); + float* grad_log_coeffs_buf = (float*)malloc(args->N * sizeof(float)); + float* grad_log_values_buf = (float*)malloc(args->N * sizeof(float)); + for (int i = 0; i < args->N; ++i) { + gate_buf[i] = rand1() * 5.0f; + hidden_buf[i] = rand1() * 5.0f; + grad_log_coeffs_buf[i] = rand1(); + grad_log_values_buf[i] = rand1(); + } + + cudaMemcpy(args->gate, gate_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->hidden, hidden_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->grad_log_coeffs, grad_log_coeffs_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->grad_log_values, grad_log_values_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + + free(gate_buf); + free(hidden_buf); + free(grad_log_coeffs_buf); + free(grad_log_values_buf); + + return args; +} + +void free_logcoeffsandvaluesargs(LogCoeffsAndValuesArgs* args) { + cudaFree(args->gate); + cudaFree(args->hidden); + cudaFree(args->log_coeffs); + cudaFree(args->log_values); + cudaFree(args->grad_gate); + cudaFree(args->grad_hidden); + cudaFree(args->grad_log_coeffs); + cudaFree(args->grad_log_values); + free(args); +} + +void run_logcoeffsandvalues_forward(LogCoeffsAndValuesArgs* args) { + launch_log_coeffs_and_values( + args->log_coeffs, args->log_values, args->gate, args->hidden, args->N, 0); +} + +void run_logcoeffsandvalues_backward(LogCoeffsAndValuesArgs* args) { + launch_log_coeffs_and_values_backward( + args->grad_gate, args->grad_hidden, args->grad_log_coeffs, + args->grad_log_values, args->gate, args->hidden, args->N, 0); +} + +#ifdef USE_TORCH + +typedef struct { + torch::Tensor gate; + torch::Tensor hidden; + torch::Tensor grad_log_coeffs; + torch::Tensor grad_log_values; + torch::Tensor out_log_coeffs; + torch::Tensor out_log_values; + int N; +} LogCoeffsAndValuesArgsTorch; + +LogCoeffsAndValuesArgsTorch* create_logcoeffsandvaluesargs_torch(LogCoeffsAndValuesArgs* raw) { + LogCoeffsAndValuesArgsTorch* args = new LogCoeffsAndValuesArgsTorch(); + args->N = raw->N; + + auto opts = 
torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + args->gate = torch::from_blob(raw->gate, {raw->N}, opts).requires_grad_(true); + args->hidden = torch::from_blob(raw->hidden, {raw->N}, opts).requires_grad_(true); + args->grad_log_coeffs = torch::from_blob(raw->grad_log_coeffs, {raw->N}, opts); + args->grad_log_values = torch::from_blob(raw->grad_log_values, {raw->N}, opts); + + return args; +} + +void run_logcoeffsandvalues_forward_torch(LogCoeffsAndValuesArgsTorch* args) { + torch::NoGradGuard no_grad; + log_coeffs_and_values(args->gate, args->hidden); +} + +void run_logcoeffsandvalues_backward_torch(LogCoeffsAndValuesArgsTorch* args) { + args->gate.mutable_grad() = torch::Tensor(); + args->hidden.mutable_grad() = torch::Tensor(); + torch::autograd::backward( + {args->out_log_coeffs, args->out_log_values}, + {args->grad_log_coeffs, args->grad_log_values}, + /*retain_graph=*/true); +} + +void run_logcoeffsandvalues_forward_cpp(LogCoeffsAndValuesArgsTorch* args) { + torch::NoGradGuard no_grad; + log_coeffs_and_values_cpp(args->gate, args->hidden); +} + +#endif + +void profile_logcoeffsandvalues(int batch, int seq, int hidden) { + LogCoeffsAndValuesArgs* args = create_logcoeffsandvaluesargs(batch, seq, hidden); + + printf("log_coeffs_and_values (N=%d, %dx%dx%d)\n", args->N, batch, seq, hidden); + + float fwd_ms = profile_kernel((kernel_fn)run_logcoeffsandvalues_forward, args); + print_timing("\tforward", fwd_ms, batch*seq); + + float bwd_ms = profile_kernel((kernel_fn)run_logcoeffsandvalues_backward, args); + print_timing("\tbackward", bwd_ms, batch*seq); + +#ifdef USE_TORCH + LogCoeffsAndValuesArgsTorch* args_torch = create_logcoeffsandvaluesargs_torch(args); + + float fwd_torch_ms = profile_kernel((kernel_fn)run_logcoeffsandvalues_forward_torch, args_torch); + print_timing("\tforward (torch)", fwd_torch_ms, batch*seq); + + auto kernel_outputs = log_coeffs_and_values(args_torch->gate, args_torch->hidden); + args_torch->out_log_coeffs = kernel_outputs[0]; + args_torch->out_log_values = kernel_outputs[1]; + + float bwd_torch_ms = profile_kernel((kernel_fn)run_logcoeffsandvalues_backward_torch, args_torch); + print_timing("\tbackward (torch)", bwd_torch_ms, batch*seq); + + float fwd_cpp_ms = profile_kernel((kernel_fn)run_logcoeffsandvalues_forward_cpp, args_torch); + print_timing("\tforward (cpp)", fwd_cpp_ms, batch*seq); + + auto cpp_outputs = log_coeffs_and_values_cpp(args_torch->gate, args_torch->hidden); + args_torch->out_log_coeffs = cpp_outputs[0]; + args_torch->out_log_values = cpp_outputs[1]; + + float bwd_cpp_ms = profile_kernel((kernel_fn)run_logcoeffsandvalues_backward_torch, args_torch); + print_timing("\tbackward (cpp)", bwd_cpp_ms, batch*seq); + + float fwd_graph_ms = profile_graph((kernel_fn)run_logcoeffsandvalues_forward_cpp, args_torch); + print_timing("\tforward (graph)", fwd_graph_ms, batch*seq); + + delete args_torch; +#endif + printf("\n"); + + free_logcoeffsandvaluesargs(args); +} + +typedef struct { + float* x; + float* out; + double* s_buf; + float* grad_x; + float* grad_out; + int B; + int T; + int H; + int N; +} LogcumsumexpArgs; + +LogcumsumexpArgs* create_logcumsumexpargs(int batch, int seq, int hidden) { + LogcumsumexpArgs* args = (LogcumsumexpArgs*)calloc(1, sizeof(LogcumsumexpArgs)); + args->B = batch; + args->T = seq; + args->H = hidden; + args->N = batch*seq * hidden; + + cudaMalloc(&args->x, args->N * sizeof(float)); + cudaMalloc(&args->out, args->N * sizeof(float)); + cudaMalloc(&args->s_buf, args->N * sizeof(double)); + 
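+    // s_buf is double precision: presumably the per-element running-sum state that the
+    // forward scan saves for reuse in the backward pass (it is passed to both
+    // launch_logcumsumexp_forward and launch_logcumsumexp_backward); fp64 is assumed
+    // here to limit accumulation error over long sequences.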
cudaMalloc(&args->grad_x, args->N * sizeof(float)); + cudaMalloc(&args->grad_out, args->N * sizeof(float)); + + float* buf = (float*)malloc(args->N * sizeof(float) * 2); + float* x_buf = buf; + float* grad_out_buf = buf + args->N; + for (int i = 0; i < args->N; ++i) { + x_buf[i] = rand1(); + grad_out_buf[i] = rand1(); + } + + cudaMemcpy(args->x, x_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->grad_out, grad_out_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + + free(buf); + return args; +} + +void free_logcumsumexpargs(LogcumsumexpArgs* args) { + cudaFree(args->x); + cudaFree(args->out); + cudaFree(args->s_buf); + cudaFree(args->grad_x); + cudaFree(args->grad_out); + free(args); +} + +void run_logcumsumexp_forward(LogcumsumexpArgs* args) { + launch_logcumsumexp_forward( + args->out, args->s_buf, args->x, args->T, args->H, args->B, 0); +} + +void run_logcumsumexp_backward(LogcumsumexpArgs* args) { + launch_logcumsumexp_backward( + args->grad_x, args->grad_out, args->x, args->s_buf, args->T, args->H, args->B, 0); +} + +#ifdef USE_TORCH + +typedef struct { + torch::Tensor x; + torch::Tensor out; + torch::Tensor grad_out; + int N; +} LogcumsumexpArgsTorch; + +LogcumsumexpArgsTorch* create_logcumsumexpargs_torch(LogcumsumexpArgs* raw) { + LogcumsumexpArgsTorch* args = new LogcumsumexpArgsTorch(); + args->N = raw->N; + + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + args->x = torch::from_blob(raw->x, {raw->B, raw->T, raw->H}, opts).requires_grad_(true); + args->grad_out = torch::from_blob(raw->grad_out, {raw->B, raw->T, raw->H}, opts); + + return args; +} + +void run_logcumsumexp_forward_torch(LogcumsumexpArgsTorch* args) { + torch::NoGradGuard no_grad; + logcumsumexp_cuda(args->x); +} + +void run_logcumsumexp_backward_torch(LogcumsumexpArgsTorch* args) { + args->x.mutable_grad() = torch::Tensor(); + args->out.backward(args->grad_out, /*retain_graph=*/true); +} + +void run_logcumsumexp_forward_cpp(LogcumsumexpArgsTorch* args) { + torch::NoGradGuard no_grad; + logcumsumexp_cpp(args->x); +} + +#endif + +void profile_logcumsumexp(int batch, int seq, int hidden) { + LogcumsumexpArgs* args = create_logcumsumexpargs(batch, seq, hidden); + + printf("logcumsumexp (N=%d, %dx%dx%d)\n", args->N, batch, seq, hidden); + + float fwd_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward, args); + print_timing("\tforward", fwd_ms, batch*seq); + + float bwd_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward, args); + print_timing("\tbackward", bwd_ms, batch*seq); + +#ifdef USE_TORCH + LogcumsumexpArgsTorch* args_torch = create_logcumsumexpargs_torch(args); + + float fwd_torch_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_torch, args_torch); + print_timing("\tforward (torch)", fwd_torch_ms, batch*seq); + + args_torch->out = logcumsumexp_cuda(args_torch->x); + + float bwd_torch_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_torch, args_torch); + print_timing("\tbackward (torch)", bwd_torch_ms, batch*seq); + + float fwd_cpp_ms = profile_kernel((kernel_fn)run_logcumsumexp_forward_cpp, args_torch); + print_timing("\tforward (cpp)", fwd_cpp_ms, batch*seq); + + args_torch->out = logcumsumexp_cpp(args_torch->x); + + float bwd_cpp_ms = profile_kernel((kernel_fn)run_logcumsumexp_backward_torch, args_torch); + print_timing("\tbackward (cpp)", bwd_cpp_ms, batch*seq); + + float fwd_graph_ms = profile_graph((kernel_fn)run_logcumsumexp_forward_cpp, args_torch); + print_timing("\tforward (graph)", fwd_graph_ms, batch*seq); + + 
delete args_torch; +#endif + printf("\n"); + + free_logcumsumexpargs(args); +} + +// New fused_scan takes combined (B, T, 3*H) = [hidden, gate, proj] and state (B, 1, H) +// Outputs: out (B, T, H) = sigmoid(proj) * scan_result, next_state (B, 1, H) +typedef struct { + float* combined; // (B, T, 3*H) = [hidden, gate, proj] + float* state; // (B, 1, H) + float* out; // (B, T, H) + float* next_state; // (B, 1, H) + float* a_star; // (B, T+1, H) + float* s_vals; // (B, T+1, H) + float* log_values_buf; // (B, T+1, H) + float* grad_combined; // (B, T, 3*H) + float* grad_state; // (B, 1, H) + float* grad_out; // (B, T, H) + float* grad_next_state;// (B, 1, H) + int B; + int T; + int H; + int N; +} FusedScanArgs; + +FusedScanArgs* create_fusedscanargs(int batch, int seq, int hidden) { + FusedScanArgs* args = (FusedScanArgs*)calloc(1, sizeof(FusedScanArgs)); + args->B = batch; + args->T = seq; + args->H = hidden; + args->N = batch * seq * hidden; + + int N_combined = batch * seq * 3 * hidden; + int N_state = batch * hidden; + int N_buf = batch * (seq + 1) * hidden; + + cudaMalloc(&args->combined, N_combined * sizeof(float)); + cudaMalloc(&args->state, N_state * sizeof(float)); + cudaMalloc(&args->out, args->N * sizeof(float)); + cudaMalloc(&args->next_state, N_state * sizeof(float)); + cudaMalloc(&args->a_star, N_buf * sizeof(float)); + cudaMalloc(&args->s_vals, N_buf * sizeof(float)); + cudaMalloc(&args->log_values_buf, N_buf * sizeof(float)); + cudaMalloc(&args->grad_combined, N_combined * sizeof(float)); + cudaMalloc(&args->grad_state, N_state * sizeof(float)); + cudaMalloc(&args->grad_out, args->N * sizeof(float)); + cudaMalloc(&args->grad_next_state, N_state * sizeof(float)); + + // Allocate and initialize host buffers + float* combined_buf = (float*)malloc(N_combined * sizeof(float)); + float* state_buf = (float*)malloc(N_state * sizeof(float)); + float* grad_out_buf = (float*)malloc(args->N * sizeof(float)); + float* grad_next_state_buf = (float*)malloc(N_state * sizeof(float)); + + // Initialize combined = [hidden, gate, proj] with reasonable values + for (int b = 0; b < batch; ++b) { + for (int t = 0; t < seq; ++t) { + for (int h = 0; h < hidden; ++h) { + int base = b * seq * 3 * hidden + t * 3 * hidden; + combined_buf[base + h] = rand1() * 5.0f; // hidden + combined_buf[base + hidden + h] = rand1() * 5.0f; // gate + combined_buf[base + 2 * hidden + h] = rand1() * 2.0f; // proj + } + } + } + // Initialize state with positive values (will be log'd) + for (int i = 0; i < N_state; ++i) { + state_buf[i] = fabsf(rand1()) + 0.1f; + } + // Initialize gradients + for (int i = 0; i < args->N; ++i) { + grad_out_buf[i] = rand1(); + } + for (int i = 0; i < N_state; ++i) { + grad_next_state_buf[i] = rand1(); + } + + cudaMemcpy(args->combined, combined_buf, N_combined * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->state, state_buf, N_state * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->grad_out, grad_out_buf, args->N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->grad_next_state, grad_next_state_buf, N_state * sizeof(float), cudaMemcpyHostToDevice); + + free(combined_buf); + free(state_buf); + free(grad_out_buf); + free(grad_next_state_buf); + return args; +} + +void free_fusedscanargs(FusedScanArgs* args) { + cudaFree(args->combined); + cudaFree(args->state); + cudaFree(args->out); + cudaFree(args->next_state); + cudaFree(args->a_star); + cudaFree(args->s_vals); + cudaFree(args->log_values_buf); + cudaFree(args->grad_combined); + cudaFree(args->grad_state); + 
cudaFree(args->grad_out); + cudaFree(args->grad_next_state); + free(args); +} + +void run_fusedscan_forward(FusedScanArgs* args) { + launch_fused_scan_forward( + args->out, args->next_state, + args->a_star, args->s_vals, args->log_values_buf, + args->combined, args->state, + args->T, args->H, args->B, 0); +} + +void run_fusedscan_backward(FusedScanArgs* args) { + launch_fused_scan_backward( + args->grad_combined, args->grad_state, + args->grad_out, args->grad_next_state, + args->combined, args->state, + args->a_star, args->s_vals, args->log_values_buf, + args->T, args->H, args->B, 0); +} + +#ifdef USE_TORCH + +typedef struct { + torch::Tensor combined; // (B, T, 3*H) + torch::Tensor state; // (B, 1, H) + torch::Tensor out; // (B, T, H) + torch::Tensor next_state; // (B, 1, H) + torch::Tensor grad_out; // (B, T, H) + torch::Tensor grad_next_state; // (B, 1, H) + int B; + int T; + int H; +} FusedScanArgsTorch; + +FusedScanArgsTorch* create_fusedscanargs_torch(FusedScanArgs* raw) { + FusedScanArgsTorch* args = new FusedScanArgsTorch(); + args->B = raw->B; + args->T = raw->T; + args->H = raw->H; + + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + args->combined = torch::from_blob(raw->combined, {raw->B, raw->T, 3 * raw->H}, opts).requires_grad_(true); + args->state = torch::from_blob(raw->state, {raw->B, 1, raw->H}, opts).requires_grad_(true); + args->grad_out = torch::from_blob(raw->grad_out, {raw->B, raw->T, raw->H}, opts); + args->grad_next_state = torch::from_blob(raw->grad_next_state, {raw->B, 1, raw->H}, opts); + + return args; +} + +void run_fusedscan_forward_torch(FusedScanArgsTorch* args) { + torch::NoGradGuard no_grad; + fused_scan(args->combined, args->state); +} + +void run_fusedscan_backward_torch(FusedScanArgsTorch* args) { + args->combined.mutable_grad() = torch::Tensor(); + args->state.mutable_grad() = torch::Tensor(); + torch::autograd::backward( + {args->out, args->next_state}, + {args->grad_out, args->grad_next_state}, + /*retain_graph=*/true); +} + +void run_fusedscan_forward_cpp(FusedScanArgsTorch* args) { + torch::NoGradGuard no_grad; + fused_scan_cpp(args->combined, args->state); +} + +#endif + +void profile_fusedscan(int batch, int seq, int hidden) { + FusedScanArgs* args = create_fusedscanargs(batch, seq, hidden); + + printf("fused_scan (N=%d, %dx%dx%d, combined=%dx%dx%d)\n", + args->N, batch, seq, hidden, batch, seq, 3*hidden); + + float fwd_ms = profile_kernel((kernel_fn)run_fusedscan_forward, args); + print_timing("\tforward", fwd_ms, batch*seq); + + float bwd_ms = profile_kernel((kernel_fn)run_fusedscan_backward, args); + print_timing("\tbackward", bwd_ms, batch*seq); + +#ifdef USE_TORCH + FusedScanArgsTorch* args_torch = create_fusedscanargs_torch(args); + + float fwd_torch_ms = profile_kernel((kernel_fn)run_fusedscan_forward_torch, args_torch); + print_timing("\tforward (torch)", fwd_torch_ms, batch*seq); + + auto scan_out = fused_scan(args_torch->combined, args_torch->state); + args_torch->out = scan_out[0]; + args_torch->next_state = scan_out[1]; + + float bwd_torch_ms = profile_kernel((kernel_fn)run_fusedscan_backward_torch, args_torch); + print_timing("\tbackward (torch)", bwd_torch_ms, batch*seq); + + float fwd_cpp_ms = profile_kernel((kernel_fn)run_fusedscan_forward_cpp, args_torch); + print_timing("\tforward (cpp)", fwd_cpp_ms, batch*seq); + + auto scan_out_cpp = fused_scan_cpp(args_torch->combined, args_torch->state); + args_torch->out = scan_out_cpp[0]; + args_torch->next_state = scan_out_cpp[1]; + + float bwd_cpp_ms = 
profile_kernel((kernel_fn)run_fusedscan_backward_torch, args_torch); + print_timing("\tbackward (cpp)", bwd_cpp_ms, batch*seq); + + float fwd_graph_ms = profile_graph((kernel_fn)run_fusedscan_forward_cpp, args_torch); + print_timing("\tforward (graph)", fwd_graph_ms, batch*seq); + + delete args_torch; +#endif + printf("\n"); + + free_fusedscanargs(args); +} + +typedef struct { + float* logits; + float* values_pred; + int64_t* actions; + float* old_logprobs; + float* advantages; + float* prio; + float* values; + float* returns; + float* adv_mean; + float* adv_std; + float* loss; + double* saved_for_backward; + float* grad_logits; + float* grad_values_pred; + float* grad_loss; + float clip_coef; + float vf_clip_coef; + float vf_coef; + float ent_coef; + int N; + int T; + int A; +} PPOLossArgs; + +PPOLossArgs* create_ppolossargs(int batch, int seq, int actions) { + PPOLossArgs* args = (PPOLossArgs*)calloc(1, sizeof(PPOLossArgs)); + args->N = batch; + args->T = seq; + args->A = actions; + + int NT = batch*seq; + int NTA = batch*seq * actions; + + cudaMalloc(&args->logits, NTA * sizeof(float)); + cudaMalloc(&args->values_pred, NT * sizeof(float)); + cudaMalloc(&args->actions, NT * sizeof(int64_t)); + cudaMalloc(&args->old_logprobs, NT * sizeof(float)); + cudaMalloc(&args->advantages, NT * sizeof(float)); + cudaMalloc(&args->prio, batch * sizeof(float)); + cudaMalloc(&args->values, NT * sizeof(float)); + cudaMalloc(&args->returns, NT * sizeof(float)); + cudaMalloc(&args->adv_mean, sizeof(float)); + cudaMalloc(&args->adv_std, sizeof(float)); + cudaMalloc(&args->loss, sizeof(float)); + cudaMalloc(&args->saved_for_backward, NT * 5 * sizeof(double)); + cudaMalloc(&args->grad_logits, NTA * sizeof(float)); + cudaMalloc(&args->grad_values_pred, NT * sizeof(float)); + cudaMalloc(&args->grad_loss, sizeof(float)); + + float* buf = (float*)malloc((NTA + NT * 5 + batch) * sizeof(float)); + float* logits_buf = buf; + float* values_pred_buf = buf + NTA; + float* old_logprobs_buf = buf + NTA + NT; + float* advantages_buf = buf + NTA + NT * 2; + float* values_buf = buf + NTA + NT * 3; + float* returns_buf = buf + NTA + NT * 4; + float* prio_buf = buf + NTA + NT * 5; + + int64_t* actions_buf = (int64_t*)malloc(NT * sizeof(int64_t)); + + float adv_sum = 0.0f, adv_sq_sum = 0.0f; + for (int i = 0; i < NT; ++i) { + advantages_buf[i] = rand1(); + adv_sum += advantages_buf[i]; + adv_sq_sum += advantages_buf[i] * advantages_buf[i]; + } + float adv_mean = adv_sum / NT; + float adv_std = sqrtf(adv_sq_sum / NT - adv_mean * adv_mean); + + for (int i = 0; i < NTA; ++i) { + logits_buf[i] = rand1() * 2.0f; + } + for (int i = 0; i < NT; ++i) { + values_pred_buf[i] = rand1(); + actions_buf[i] = rand() % actions; + old_logprobs_buf[i] = rand1() * 2.0f; + values_buf[i] = rand1(); + returns_buf[i] = rand1(); + } + for (int i = 0; i < batch; ++i) { + prio_buf[i] = (float)rand() / RAND_MAX; + } + + cudaMemcpy(args->logits, logits_buf, NTA * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->values_pred, values_pred_buf, NT * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->actions, actions_buf, NT * sizeof(int64_t), cudaMemcpyHostToDevice); + cudaMemcpy(args->old_logprobs, old_logprobs_buf, NT * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->advantages, advantages_buf, NT * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->prio, prio_buf, batch * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->values, values_buf, NT * sizeof(float), cudaMemcpyHostToDevice); + 
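+    /* adv_mean / adv_std were computed on the host above so that the raw CUDA
+     * kernels and the torch / cpp reference paths all normalize advantages with
+     * identical statistics. */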
cudaMemcpy(args->returns, returns_buf, NT * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->adv_mean, &adv_mean, sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(args->adv_std, &adv_std, sizeof(float), cudaMemcpyHostToDevice); + + float grad_loss_val = 1.0f; + cudaMemcpy(args->grad_loss, &grad_loss_val, sizeof(float), cudaMemcpyHostToDevice); + + args->clip_coef = 0.1f; + args->vf_clip_coef = 0.1f; + args->vf_coef = 0.5f; + args->ent_coef = 0.01f; + + free(buf); + free(actions_buf); + return args; +} + +void free_ppolossargs(PPOLossArgs* args) { + cudaFree(args->logits); + cudaFree(args->values_pred); + cudaFree(args->actions); + cudaFree(args->old_logprobs); + cudaFree(args->advantages); + cudaFree(args->prio); + cudaFree(args->values); + cudaFree(args->returns); + cudaFree(args->adv_mean); + cudaFree(args->adv_std); + cudaFree(args->loss); + cudaFree(args->saved_for_backward); + cudaFree(args->grad_logits); + cudaFree(args->grad_values_pred); + cudaFree(args->grad_loss); + free(args); +} + +void run_ppoloss_forward(PPOLossArgs* args) { + launch_ppo_loss_forward( + args->loss, args->saved_for_backward, + args->logits, args->values_pred, args->actions, + args->old_logprobs, args->advantages, args->prio, + args->values, args->returns, args->adv_mean, args->adv_std, + args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef, + args->T, args->A, args->N, 0); +} + +void run_ppoloss_backward(PPOLossArgs* args) { + launch_ppo_loss_backward( + args->grad_logits, args->grad_values_pred, args->grad_loss, + args->logits, args->actions, args->old_logprobs, + args->advantages, args->prio, args->values, args->returns, + args->saved_for_backward, args->adv_mean, args->adv_std, + args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef, + args->T, args->A, args->N, 0); +} + +#ifdef USE_TORCH + +typedef struct { + torch::Tensor logits; + torch::Tensor values_pred; + torch::Tensor actions; + torch::Tensor old_logprobs; + torch::Tensor advantages; + torch::Tensor prio; + torch::Tensor values; + torch::Tensor returns; + torch::Tensor adv_mean; + torch::Tensor adv_std; + torch::Tensor loss; + float clip_coef; + float vf_clip_coef; + float vf_coef; + float ent_coef; + int N; + int T; + int A; +} PPOLossArgsTorch; + +PPOLossArgsTorch* create_ppolossargs_torch(PPOLossArgs* raw) { + PPOLossArgsTorch* args = new PPOLossArgsTorch(); + args->N = raw->N; + args->T = raw->T; + args->A = raw->A; + args->clip_coef = raw->clip_coef; + args->vf_clip_coef = raw->vf_clip_coef; + args->vf_coef = raw->vf_coef; + args->ent_coef = raw->ent_coef; + + auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto opts_int = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + + args->logits = torch::from_blob(raw->logits, {raw->N, raw->T, raw->A}, opts).requires_grad_(true); + args->values_pred = torch::from_blob(raw->values_pred, {raw->N, raw->T}, opts).requires_grad_(true); + args->actions = torch::from_blob(raw->actions, {raw->N, raw->T}, opts_int); + args->old_logprobs = torch::from_blob(raw->old_logprobs, {raw->N, raw->T}, opts); + args->advantages = torch::from_blob(raw->advantages, {raw->N, raw->T}, opts); + args->prio = torch::from_blob(raw->prio, {raw->N}, opts); + args->values = torch::from_blob(raw->values, {raw->N, raw->T}, opts); + args->returns = torch::from_blob(raw->returns, {raw->N, raw->T}, opts); + args->adv_mean = torch::from_blob(raw->adv_mean, {1}, opts); + args->adv_std = torch::from_blob(raw->adv_std, {1}, opts); + + return args; +} 
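+// Reference-only sketch (an assumption, not the fused kernel itself): the wrappers
+// below are presumed to benchmark a standard clipped PPO surrogate with a clipped
+// value loss and an entropy bonus. The real fused_ppo_loss may differ in details
+// (e.g. how the prio weights enter); this helper is documentation only and is never called.
+static torch::Tensor ppo_loss_reference(
+        torch::Tensor logits, torch::Tensor values_pred, torch::Tensor actions,
+        torch::Tensor old_logprobs, torch::Tensor advantages,
+        torch::Tensor values, torch::Tensor returns,
+        torch::Tensor adv_mean, torch::Tensor adv_std,
+        float clip_coef, float vf_clip_coef, float vf_coef, float ent_coef) {
+    auto logprobs = torch::log_softmax(logits, -1);                      // (N, T, A)
+    auto new_logprob = logprobs.gather(-1, actions.unsqueeze(-1)).squeeze(-1);
+    auto entropy = -(logprobs.exp() * logprobs).sum(-1).mean();
+    auto ratio = (new_logprob - old_logprobs).exp();
+    auto adv = (advantages - adv_mean) / (adv_std + 1e-8);               // normalized advantages
+    auto pg_loss = torch::max(-adv * ratio,
+        -adv * torch::clamp(ratio, 1.0f - clip_coef, 1.0f + clip_coef)).mean();
+    auto v_clipped = values + torch::clamp(values_pred - values, -vf_clip_coef, vf_clip_coef);
+    auto v_loss = 0.5f * torch::max((values_pred - returns).pow(2),
+        (v_clipped - returns).pow(2)).mean();
+    return pg_loss + vf_coef * v_loss - ent_coef * entropy;
+}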
+ +void run_ppoloss_forward_torch(PPOLossArgsTorch* args) { + torch::NoGradGuard no_grad; + fused_ppo_loss( + args->logits, args->values_pred, args->actions, + args->old_logprobs, args->advantages, args->prio, + args->values, args->returns, args->adv_mean, args->adv_std, + args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef); +} + +void run_ppoloss_backward_torch(PPOLossArgsTorch* args) { + args->logits.mutable_grad() = torch::Tensor(); + args->values_pred.mutable_grad() = torch::Tensor(); + args->loss.backward({}, /*retain_graph=*/true); +} + +void run_ppoloss_forward_cpp(PPOLossArgsTorch* args) { + torch::NoGradGuard no_grad; + fused_ppo_loss_cpp( + args->logits, args->values_pred, args->actions, + args->old_logprobs, args->advantages, args->prio, + args->values, args->returns, args->adv_mean, args->adv_std, + args->clip_coef, args->vf_clip_coef, args->vf_coef, args->ent_coef); +} + +#endif + +void profile_ppoloss(int batch, int seq, int actions) { + PPOLossArgs* args = create_ppolossargs(batch, seq, actions); + + int NT = batch*seq; + printf("ppo_loss (NT=%d, %dx%d, A=%d)\n", NT, batch, seq, actions); + + float fwd_ms = profile_kernel((kernel_fn)run_ppoloss_forward, args); + print_timing("\tforward", fwd_ms, NT); + + float bwd_ms = profile_kernel((kernel_fn)run_ppoloss_backward, args); + print_timing("\tbackward", bwd_ms, NT); + +#ifdef USE_TORCH + PPOLossArgsTorch* args_torch = create_ppolossargs_torch(args); + + float fwd_torch_ms = profile_kernel((kernel_fn)run_ppoloss_forward_torch, args_torch); + print_timing("\tforward (torch)", fwd_torch_ms, NT); + + args_torch->loss = fused_ppo_loss( + args_torch->logits, args_torch->values_pred, args_torch->actions, + args_torch->old_logprobs, args_torch->advantages, args_torch->prio, + args_torch->values, args_torch->returns, args_torch->adv_mean, args_torch->adv_std, + args_torch->clip_coef, args_torch->vf_clip_coef, args_torch->vf_coef, args_torch->ent_coef)[0]; + + float bwd_torch_ms = profile_kernel((kernel_fn)run_ppoloss_backward_torch, args_torch); + print_timing("\tbackward (torch)", bwd_torch_ms, NT); + + float fwd_cpp_ms = profile_kernel((kernel_fn)run_ppoloss_forward_cpp, args_torch); + print_timing("\tforward (cpp)", fwd_cpp_ms, NT); + + args_torch->loss = fused_ppo_loss_cpp( + args_torch->logits, args_torch->values_pred, args_torch->actions, + args_torch->old_logprobs, args_torch->advantages, args_torch->prio, + args_torch->values, args_torch->returns, args_torch->adv_mean, args_torch->adv_std, + args_torch->clip_coef, args_torch->vf_clip_coef, args_torch->vf_coef, args_torch->ent_coef); + + float bwd_cpp_ms = profile_kernel((kernel_fn)run_ppoloss_backward_torch, args_torch); + print_timing("\tbackward (cpp)", bwd_cpp_ms, NT); + + float fwd_graph_ms = profile_graph((kernel_fn)run_ppoloss_forward_cpp, args_torch); + print_timing("\tforward (graph)", fwd_graph_ms, NT); + + delete args_torch; +#endif + printf("\n"); + + free_ppolossargs(args); +} + +int main(int argc, char** argv) { + warmup_gpu(); + profile_mingrugate(BR, H); + profile_logcoeffsandvalues(BT, T, H); + profile_logcumsumexp(BT, T, H); + profile_fusedscan(BT, T, H); + profile_ppoloss(BT, T, A); + return 0; +} diff --git a/profile_torch.py b/profile_torch.py new file mode 100644 index 000000000..656e7447a --- /dev/null +++ b/profile_torch.py @@ -0,0 +1,65 @@ +import torch +from torch import nn +from torch.utils.benchmark import Timer +from torch.utils.flop_counter import FlopCounterMode + +from torch.backends import cudnn +cudnn.benchmark = True 
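+# benchmark=True lets cuDNN autotune kernels per input shape; together with the
+# deterministic / benchmark_limit settings below, reproducibility is traded for speed.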
+cudnn.deterministic = False +cudnn.benchmark_limit = 32 + +torch.set_float32_matmul_precision('high') +torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + +INPUT_SIZE = 128 +HIDDEN_SIZE1 = 128 +HIDDEN_SIZE2 = 512 +OUTPUT_SIZE = 128 +B = 8192 +dtype = torch.bfloat16 +inner_loops = 100 # Number of inner iterations to amortize overhead + +# Define the model with explicit Kaiming uniform initialization to match JAX +model = torch.nn.Sequential( + torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE1), + torch.nn.ReLU(), + torch.nn.Linear(HIDDEN_SIZE1, HIDDEN_SIZE2), + torch.nn.ReLU(), + torch.nn.Linear(HIDDEN_SIZE2, OUTPUT_SIZE), +).cuda().to(dtype) + +# Create input batch +batch = torch.randn(B, INPUT_SIZE).cuda().to(dtype) + +# Define a multi-step function to run multiple forwards in one compiled graph +@torch.compile(mode='max-autotune') +def multi_step(model, batch, inner_loops): + with torch.no_grad(): + carry = torch.tensor(0.0, dtype=torch.float32, device='cuda') + for i in range(inner_loops): + y = model(batch) + carry = carry + y.sum() + + return carry + +# Manual FLOPs calculation to match JAX (ignores bias adds and ReLUs as negligible) +flops = ( + 2 * B * INPUT_SIZE * HIDDEN_SIZE1 + + 2 * B * HIDDEN_SIZE1 * HIDDEN_SIZE2 + + 2 * B * HIDDEN_SIZE2 * OUTPUT_SIZE +) + +# Warmup +for _ in range(10): + _ = multi_step(model, batch, inner_loops) + +# Timing +timer = Timer( + stmt='multi_step(model, batch, inner_loops)', + globals={'multi_step': multi_step, 'model': model, 'batch': batch, 'inner_loops': inner_loops} +) +output = timer.timeit(50) + +cost = output.mean / inner_loops # Average time per forward pass (fixed from times[0] to mean) +FLOPS = flops / cost +print(f'TFLOPS: {FLOPS / 1e12:.2f}') diff --git a/pufferlib/config/cogames.ini b/pufferlib/config/cogames.ini index 674b48e2e..b50ad3bef 100644 --- a/pufferlib/config/cogames.ini +++ b/pufferlib/config/cogames.ini @@ -5,7 +5,7 @@ policy_name = Policy rnn_name = Recurrent [vec] -num_envs = 64 +num_envs = 4096 num_workers = 16 batch_size = auto zero_copy = True @@ -15,7 +15,7 @@ render_mode = none variants = heart_chorus inventory_heart_tune [train] -total_timesteps = 50_000_000 +total_timesteps = 3_000_000_000 batch_size = auto -minibatch_size = 1024 +minibatch_size = 32768 bptt_horizon = 64 diff --git a/pufferlib/config/default.ini b/pufferlib/config/default.ini index cc4bf1dae..595dc2261 100644 --- a/pufferlib/config/default.ini +++ b/pufferlib/config/default.ini @@ -32,7 +32,7 @@ anneal_lr = True min_lr_ratio = 0.0 gamma = 0.995 gae_lambda = 0.90 -update_epochs = 1 +num_minibatches = 16 clip_coef = 0.2 vf_coef = 2.0 vf_clip_coef = 0.2 @@ -51,7 +51,7 @@ minibatch_size = 8192 max_minibatch_size = 32768 bptt_horizon = 64 compile = False -compile_mode = max-autotune-no-cudagraphs +compile_mode = reduce-overhead compile_fullgraph = True vtrace_rho_clip = 1.0 @@ -60,6 +60,8 @@ vtrace_c_clip = 1.0 prio_alpha = 0.8 prio_beta0 = 0.2 +max_cost = -1 + [sweep] method = Protein metric = score @@ -82,6 +84,20 @@ min = 3e7 max = 1e10 scale = time +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 1024 +mean = 128 +scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 +mean = 2048 +scale = auto + [sweep.train.bptt_horizon] distribution = uniform_pow2 min = 16 @@ -90,7 +106,7 @@ scale = auto [sweep.train.minibatch_size] distribution = uniform_pow2 -min = 8192 +min = 512 max = 65536 scale = auto @@ -130,11 +146,12 @@ min = 0.1 max = 5.0 scale = auto -#[sweep.train.update_epochs] 
-#distribution = int_uniform -#min = 1 -#max = 8 -#scale = 2.0 +[sweep.train.num_minibatches] +distribution = uniform_pow2 +min = 1 +max = 1024 +mean = 32 +scale = auto [sweep.train.clip_coef] distribution = uniform diff --git a/pufferlib/config/ocean/breakout.ini b/pufferlib/config/ocean/breakout.ini index d261503f5..dabe4e737 100644 --- a/pufferlib/config/ocean/breakout.ini +++ b/pufferlib/config/ocean/breakout.ini @@ -1,14 +1,16 @@ [base] package = ocean env_name = puffer_breakout -policy_name = Policy +policy_name = MinGRU rnn_name = Recurrent [vec] -num_envs = 8 +#num_envs = 4 +num_envs = 1 [env] -num_envs = 1024 +#num_envs = 2048 +num_envs = 8192 frameskip = 4 width = 576 height = 330 @@ -27,13 +29,46 @@ continuous = 0 [policy] hidden_size = 128 +num_layers = 4 +expansion_factor = 1 -[rnn] -input_size = 128 -hidden_size = 128 +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 512 +mean = 128 +scale = auto + +[sweep.policy.num_layers] +distribution = int_uniform +min = 1 +max = 4 +mean = 2 +scale = auto + +#[sweep.policy.d_state] +#distribution = uniform_pow2 +#min = 32 +#max = 128 +#mean = 32 +#scale = auto + +#[sweep.policy.d_conv] +#distribution = int_uniform +#min = 1 +#max = 4 +#mean = 2 +#scale = auto + +[sweep.policy.expansion_factor] +distribution = int_uniform +min = 1 +max = 2 +mean = 1 +scale = auto [train] -total_timesteps = 90_000_000 +total_timesteps = 120_000_000 adam_beta1 = 0.8946507418260217 adam_beta2 = 0.9 adam_eps = 0.0001 @@ -53,11 +88,56 @@ vf_coef = 1.6832989594296321 vtrace_c_clip = 2.878171091654008 vtrace_rho_clip = 0.7876748061547312 -[sweep.train.total_timesteps] -distribution = log_normal -min = 3e7 -max = 2e8 -mean = 8e7 +#total_timesteps = 120_000_000 +#adam_beta1 = 0.8166332218104871 +#adam_beta2 = 0.9984879989750705 +#adam_eps = 0.0001 +#batch_size = auto +#bptt_horizon = 64 +#clip_coef = 0.42526610231849393 +#ent_coef = 0.0026822968018267775 +#gae_lambda = 0.995 +#gamma = 0.9731819086255716 +#learning_rate = 0.04301709139429238 +#max_grad_norm = 0.7029618837611082 +#minibatch_size = 16384 +#prio_alpha = 0.09999999999999998 +#prio_beta0 = 0.8437844355214735 +#vf_clip_coef = 0.807798225723059 +#vf_coef = 2.9089121311247554 +#vtrace_c_clip = 1.6205569942514606 +#vtrace_rho_clip = 1.1777184656786774 + +#total_timesteps = 40_000_000 +#adam_beta1 = 0.9389740236912132 +#adam_beta2 = 0.9998225039929157 +#adam_eps = 1.0267361590791064e-8 +#batch_size = auto +#bptt_horizon = 64 +#clip_coef = 0.01557913923814178 +#ent_coef = 0.0031759371032913 +#gae_lambda = 0.916681264452842 +#gamma = 0.9997053654668936 +#learning_rate = 0.012744235594115342 +#max_grad_norm = 1.8013800046071862 +#num_minibatches = 8 +#minibatch_size = 4096 +#prio_alpha = 0.9500430793857082 +#prio_beta0 = 0.9436845548994959 +#vf_clip_coef = 0.1 +#vf_coef = 2.5994729835919834 +#vtrace_c_clip = 2.878171091654008 +#vtrace_rho_clip = 1.3235791596831579 + +[sweep] +downsample = 10 +max_cost = 300 + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 +mean = 2048 scale = auto [sweep.env.frameskip] diff --git a/pufferlib/config/ocean/g2048.ini b/pufferlib/config/ocean/g2048.ini index 3ca7f4e8c..1e8196051 100644 --- a/pufferlib/config/ocean/g2048.ini +++ b/pufferlib/config/ocean/g2048.ini @@ -1,11 +1,13 @@ [base] package = ocean env_name = puffer_g2048 -policy_name = G2048 +policy_name = G2048LSTM rnn_name = Recurrent [policy] hidden_size = 512 +#num_layers = 4 +expansion_factor = 1 [rnn] input_size = 512 @@ -22,6 +24,42 @@ scaffolding_ratio = 0.67 
use_heuristic_rewards = True snake_reward_weight = 0.0005 +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 256 +mean = 128 +scale = auto + +[sweep.policy.num_layers] +distribution = int_uniform +min = 1 +max = 4 +mean = 2 +scale = auto + +[sweep.policy.d_state] +distribution = uniform_pow2 +min = 8 +max = 128 +mean = 32 +scale = auto + +[sweep.policy.d_conv] +distribution = int_uniform +min = 1 +max = 4 +mean = 2 +scale = auto + +[sweep.policy.expand] +distribution = int_uniform +min = 1 +max = 2 +mean = 1 +scale = auto + + [train] # 512 hidden: https://wandb.ai/kywch/pufferlib/runs/5thsjr61?nw=nwuserkywch total_timesteps = 6_767_676_767 @@ -30,6 +68,7 @@ min_lr_ratio = 0.15 batch_size = auto bptt_horizon = 64 minibatch_size = 32768 +num_minibatches = 32 clip_coef = 0.067 ent_coef = 0.0267 @@ -164,4 +203,4 @@ scale = auto ; min = 0.001 ; max = 0.5 ; mean = 0.05 -; scale = auto \ No newline at end of file +; scale = auto diff --git a/pufferlib/config/ocean/grid.ini b/pufferlib/config/ocean/grid.ini index 65bd540b6..e28885c2b 100644 --- a/pufferlib/config/ocean/grid.ini +++ b/pufferlib/config/ocean/grid.ini @@ -7,10 +7,6 @@ rnn_name = Recurrent [policy] hidden_size = 512 -[rnn] -input_size = 512 -hidden_size = 512 - [vec] #num_envs = 8 num_envs = 1 @@ -63,10 +59,25 @@ vtrace_rho_clip = 4.7398234531013985 [sweep] downsample = 0 +max_cost = 300 [sweep.train.total_timesteps] distribution = log_normal -min = 3e8 -max = 6e8 +min = 1e7 +max = 1e9 mean = 3e8 scale = time + +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 1024 +mean = 128 +scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 +mean = 2048 +scale = auto diff --git a/pufferlib/config/ocean/impulse_wars.ini b/pufferlib/config/ocean/impulse_wars.ini index c4b3bdcc1..50adb8008 100644 --- a/pufferlib/config/ocean/impulse_wars.ini +++ b/pufferlib/config/ocean/impulse_wars.ini @@ -6,9 +6,8 @@ rnn_name = ImpulseWarsLSTM max_suggestion_cost = 10_800 [policy] -cnn_channels = 64 -input_size = 512 hidden_size = 512 +cnn_channels = 64 # These must match what's set in env below continuous = False @@ -16,12 +15,12 @@ num_drones = 2 is_training = True [vec] -num_envs = 16 -num_workers = 16 -batch_size = 4 +num_envs = 4 +#num_workers = 4 +#batch_size = 4 [env] -num_envs = 256 +num_envs = 1024 num_drones = 2 num_agents = 1 enable_teams = False @@ -40,10 +39,14 @@ compile_mode = reduce-overhead compile_fullgraph = False device = cuda +[sweep] +downsample = 10 +max_cost = 900 + [sweep.env.num_envs] distribution = uniform_pow2 -min = 16 -max = 512 +min = 1 +max = 1024 mean = 128 scale = auto @@ -140,51 +143,3 @@ max = 256 mean = 128 scale = auto -[sweep.train.minibatch_size] -distribution = uniform_pow2 -min = 1024 -max = 262_144 -mean = 16_384 -scale = auto - -[sweep.train.learning_rate] -distribution = log_normal -min = 0.00001 -mean = 0.001 -max = 0.1 -scale = 0.5 - -[sweep.train.ent_coef] -distribution = log_normal -min = 0.000001 -mean = 0.001 -max = 0.2 -scale = auto - -[sweep.train.gamma] -distribution = logit_normal -min = 0.8 -mean = 0.98 -max = 0.99999 -scale = auto - -[sweep.train.gae_lambda] -distribution = logit_normal -min = 0.6 -mean = 0.93 -max = 0.995 -scale = auto - -[sweep.train.vf_coef] -distribution = uniform -min = 0.0 -max = 5.0 -mean = 1.0 -scale = auto - -[sweep.train.max_grad_norm] -distribution = uniform -min = 0.0 -mean = 1.0 -max = 5.0 -scale = auto diff --git a/pufferlib/config/ocean/moba.ini b/pufferlib/config/ocean/moba.ini index 
2e0e8cea3..73bcdeb68 100644 --- a/pufferlib/config/ocean/moba.ini +++ b/pufferlib/config/ocean/moba.ini @@ -12,18 +12,36 @@ reward_tower = 4.525112152099609 num_envs = 128 [vec] -num_envs = 8 +num_envs = 4 [train] total_timesteps = 150_000_000 +[sweep] +downsample = 10 +max_cost = 500 + [sweep.train.total_timesteps] distribution = log_normal min = 2e7 -max = 2e8 +max = 5e8 mean = 1e8 scale = auto +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 1024 +mean = 128 +scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 +mean = 2048 +scale = auto + [sweep.env.reward_death] distribution = uniform min = -1.0 diff --git a/pufferlib/config/ocean/nmmo3.ini b/pufferlib/config/ocean/nmmo3.ini index c04c77dc3..cb719cc41 100644 --- a/pufferlib/config/ocean/nmmo3.ini +++ b/pufferlib/config/ocean/nmmo3.ini @@ -1,11 +1,11 @@ [base] package = ocean env_name = puffer_nmmo3 -policy_name = NMMO3 +policy_name = NMMO3MinGRU rnn_name = NMMO3LSTM [vec] -num_envs = 8 +num_envs = 4 [env] reward_combat_level = 1.0 @@ -13,7 +13,12 @@ reward_prof_level = 1.0 reward_item_level = 1.0 reward_market = 0.0 reward_death = -1.0 -num_envs = 1 +num_envs = 2 + +[policy] +hidden_size = 512 +num_layers = 4 +expansion_factor = 1 [train] total_timesteps = 107000000000 @@ -31,6 +36,7 @@ max_minibatch_size = 32768 [sweep] metric = min_comb_prof +max_cost = 900 [sweep.env.num_envs] distribution = uniform_pow2 @@ -39,13 +45,6 @@ max = 8 mean = 4 scale = 0.5 -[sweep.train.total_timesteps] -distribution = log_normal -min = 2e8 -max = 1e9 -mean = 5e8 -scale = 0.5 - [sweep.env.reward_combat_level] distribution = uniform min = 0.0 diff --git a/pufferlib/config/ocean/pacman.ini b/pufferlib/config/ocean/pacman.ini index 45055e79b..07f03517e 100644 --- a/pufferlib/config/ocean/pacman.ini +++ b/pufferlib/config/ocean/pacman.ini @@ -5,7 +5,7 @@ policy_name = Policy rnn_name = Recurrent [vec] -num_envs = 8 +num_envs = 4 [env] num_envs = 1024 @@ -31,3 +31,27 @@ vf_coef = 0.31518694995467555 vtrace_c_clip = 0.30575543665366217 vtrace_rho_clip = 1.5301756939690652 +[sweep] +downsample = 10 +max_cost = 300 + +[sweep.train.total_timesteps] +distribution = log_normal +min = 2e7 +max = 5e8 +mean = 1e8 +scale = auto + +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 1024 +mean = 128 +scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 +mean = 2048 +scale = auto diff --git a/pufferlib/config/ocean/pong.ini b/pufferlib/config/ocean/pong.ini index a0bf24d93..9ca522646 100644 --- a/pufferlib/config/ocean/pong.ini +++ b/pufferlib/config/ocean/pong.ini @@ -4,15 +4,20 @@ env_name = puffer_pong policy_name = Policy rnn_name = Recurrent +[policy] +hidden_size = 512 + [vec] num_envs = 4 [env] -num_envs = 1024 -frameskip = 8 +num_envs = 1024 +frameskip = 4 [train] -total_timesteps = 12_000_000 +max_cost=20 + +total_timesteps = 500_000 adam_beta1 = 0.9766295300012044 adam_beta2 = 0.9998113167362397 adam_eps = 6.301709731262074e-9 @@ -24,6 +29,7 @@ gamma = 0.9608378504980243 learning_rate = 0.07109386062895108 max_grad_norm = 1.7820203601055993 minibatch_size = 32768 +num_minibatches = 8 prio_alpha = 0.09999999999999998 prio_beta0 = 0.7475661360032159 vf_clip_coef = 2.7025841941932303 @@ -31,16 +37,51 @@ vf_coef = 1.9960893747329385 vtrace_c_clip = 1.0873122745787867 vtrace_rho_clip = 2.784150207139061 +#total_timesteps = 20000000.0 +#learning_rate = 0.08878791349515394 +#gamma = 0.9354145180237635 +#gae_lambda = 0.9020935398076688 +#num_minibatches = 
32 +#clip_coef = 0.5882777043345978 +#vf_coef = 4.196442104147645 +#vf_clip_coef = 0.265385659520976 +#max_grad_norm = 0.3661413663411234 +#ent_coef = 0.0011560317997450196 +#adam_beta1 = 0.9462393585831101 +#adam_beta2 = 0.9667417156941432 +#adam_eps = 1.1005478999774079e-09 +#minibatch_size = 65536 +#max_minibatch_size = 32768 +#bptt_horizon = 64 +#vtrace_rho_clip = 1.8180933155594725 +#vtrace_c_clip = 1.4235484929825957 +#prio_alpha = 0.9553779337727483 +#prio_beta0 = 0.7125182812602482 + + +[sweep] +downsample = 0 + +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 1024 +scale = auto + [sweep.train.total_timesteps] distribution = log_normal -min = 1e7 -max = 2e8 -mean = 8e7 +min = 5e5 +max = 5e6 +scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 scale = auto [sweep.env.frameskip] distribution = int_uniform min = 1 max = 8 -mean = 4 scale = 2.0 diff --git a/pufferlib/config/ocean/rware.ini b/pufferlib/config/ocean/rware.ini index 705e0af3e..791c426f4 100644 --- a/pufferlib/config/ocean/rware.ini +++ b/pufferlib/config/ocean/rware.ini @@ -5,10 +5,10 @@ policy_name = Policy rnn_name = Recurrent [vec] -num_envs = 8 +num_envs = 4 [env] -num_envs = 128 +num_envs = 256 map_choice = 2 num_agents = 8 num_requested_shelves = 8 @@ -17,10 +17,3 @@ num_requested_shelves = 8 total_timesteps = 100_000_000 learning_rate = 0.05 minibatch_size = 32768 - -[sweep.train.total_timesteps] -distribution = log_normal -min = 3e7 -max = 3e8 -mean = 1e8 -scale = 0.25 diff --git a/pufferlib/config/ocean/snake.ini b/pufferlib/config/ocean/snake.ini index 3827b0252..9eafe9400 100644 --- a/pufferlib/config/ocean/snake.ini +++ b/pufferlib/config/ocean/snake.ini @@ -6,7 +6,7 @@ policy_name = Snake rnn_name = Recurrent [env] -num_envs = 4 +num_envs = 16 width = 640 height = 360 num_snakes = 256 @@ -18,7 +18,7 @@ reward_corpse = 0.1 reward_death = -1.0 [vec] -num_envs = 16 +num_envs = 1 [train] total_timesteps = 500_000_000 @@ -40,6 +40,9 @@ vf_coef = 3.9655925817980053 vtrace_c_clip = 0 vtrace_rho_clip = 0.9285200248552337 +[sweep] +max_cost = 500 + [sweep.env.reward_food] distribution = uniform min = 0.0 @@ -56,7 +59,21 @@ scale = auto [sweep.train.total_timesteps] distribution = log_normal -min = 5e7 -max = 2e8 +min = 2e7 +max = 5e8 mean = 1e8 scale = auto + +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 1024 +mean = 128 +scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 32 +mean = 8 +scale = auto diff --git a/pufferlib/config/ocean/squared.ini b/pufferlib/config/ocean/squared.ini index ac9f69d0f..c4e4ad9db 100644 --- a/pufferlib/config/ocean/squared.ini +++ b/pufferlib/config/ocean/squared.ini @@ -1,14 +1,27 @@ [base] package = ocean env_name = puffer_squared -policy_name = Policy +policy_name = MinGRU rnn_name = Recurrent +[vec] +num_envs = 1 +backend = Serial + +[policy] +hidden_size = 128 +num_layers = 1 +expand = 2 + [env] num_envs = 4096 [train] -total_timesteps = 20_000_000 -gamma = 0.95 -learning_rate = 0.05 +optimizer = adam +total_timesteps = 200_000_000 +gamma = 0.99 +learning_rate = 0.01 minibatch_size = 32768 +num_minibatches = 8 +ent_coef = 0.0 # TODO: Are numerics bad here in cpp? 
+#adam_eps = 1e-5 diff --git a/pufferlib/config/ocean/tetris.ini b/pufferlib/config/ocean/tetris.ini index 5aab21422..6d53031d4 100644 --- a/pufferlib/config/ocean/tetris.ini +++ b/pufferlib/config/ocean/tetris.ini @@ -1,11 +1,11 @@ [base] package = ocean env_name = puffer_tetris -policy_name = Policy +policy_name = MinGRU rnn_name = Recurrent [vec] -num_envs = 8 +num_envs = 4 [env] num_envs = 2048 @@ -18,10 +18,45 @@ n_noise_obs = 0 [policy] hidden_size = 256 +num_layers = 1 +#d_state = 32 +#d_conv = 4 +expand = 2 -[rnn] -input_size = 256 -hidden_size = 256 +[sweep.policy.hidden_size] +distribution = uniform_pow2 +min = 16 +max = 512 +mean = 128 +scale = auto + +[sweep.policy.num_layers] +distribution = int_uniform +min = 1 +max = 4 +mean = 2 +scale = auto + +[sweep.policy.d_state] +distribution = uniform_pow2 +min = 8 +max = 128 +mean = 32 +scale = auto + +#[sweep.policy.d_conv] +#distribution = int_uniform +#min = 1 +#max = 4 +#mean = 2 +#scale = auto + +[sweep.policy.expand] +distribution = int_uniform +min = 1 +max = 2 +mean = 1 +scale = auto [train] # https://wandb.ai/kywch/pufferlib/runs/era6a8p6?nw=nwuserkywch @@ -46,10 +81,10 @@ vf_coef = 4.74 vtrace_c_clip = 1.29 vtrace_rho_clip = 0.70 - [sweep] metric = score goal = maximize +max_cost = 3600 [sweep.train.total_timesteps] distribution = log_normal @@ -78,3 +113,10 @@ min = 0.5 mean = 0.95 max = 0.999 scale = auto + +[sweep.env.num_envs] +distribution = uniform_pow2 +min = 1 +max = 4096 +mean = 2048 +scale = auto diff --git a/pufferlib/config/ocean/tower_climb.ini b/pufferlib/config/ocean/tower_climb.ini index ce6f75d59..629c09cbc 100644 --- a/pufferlib/config/ocean/tower_climb.ini +++ b/pufferlib/config/ocean/tower_climb.ini @@ -2,56 +2,86 @@ package = ocean env_name = puffer_tower_climb policy_name = TowerClimb -rnn_name = TowerClimbLSTM +rnn_name = Recurrent + +[policy] +hidden_size = 256 + +[rnn] +hidden_size = 256 +num_layers = 1 [vec] -num_envs = 8 +num_envs = 4 [env] num_envs = 1024 -num_maps = 50 -reward_climb_row = 0.636873185634613 -reward_fall_row = -0.15898257493972778 -reward_illegal_move = -0.003928301855921745 -reward_move_block = 0.235064297914505 +reward_climb_row = 0.16 +reward_fall_row = -0.13 +reward_illegal_move = -0.005 +reward_move_block = 0.035 [train] -total_timesteps = 150_000_000 -#gamma = 0.98 -#learning_rate = 0.05 -minibatch_size = 32768 +# https://wandb.ai/kywch/pufferlib/runs/b8ym2mvu/overview +total_timesteps = 600_000_000 +anneal_lr = True +min_lr_ratio = 0.1 +batch_size = auto +bptt_horizon = 64 +minibatch_size = 65536 + +clip_coef = 0.6 +ent_coef = 0.08 +gae_lambda = 0.6 +gamma = 0.95 +vf_clip_coef = 5.0 +vf_coef = 5.0 + +learning_rate = 0.023 +max_grad_norm = 5.0 + +adam_beta1 = 0.81 +adam_beta2 = 0.95 +adam_eps = 1.0e-8 +prio_alpha = 0.99 +prio_beta0 = 0.99 +vtrace_c_clip = 3.7 +vtrace_rho_clip = 3.8 + +[sweep] +metric = perf +metric_distribution = percentile + +# configs for targeted sweep. 
Comment these out for broad sweep +; downsample = 1 +; sweep_only = reward_climb_row, reward_fall_row, reward_illegal_move, reward_move_block, learning_rate, adam_beta1, adam_beta2, adam_eps, vtrace_c_clip, vtrace_rho_clip [sweep.train.total_timesteps] distribution = uniform -min = 50_000_000 -max = 200_000_000 -mean = 100_000_000 +min = 100_000_000 +max = 2_000_000_000 scale = 0.5 [sweep.env.reward_climb_row] distribution = uniform min = 0.0 max = 1.0 -mean = 0.5 scale = auto [sweep.env.reward_fall_row] distribution = uniform min = -1.0 max = 0.0 -mean = -0.5 scale = auto [sweep.env.reward_illegal_move] distribution = uniform min = -1e-2 max = -1e-4 -mean = -1e-3 scale = auto [sweep.env.reward_move_block] distribution = uniform min = 0.0 max = 1.0 -mean = 0.5 scale = auto diff --git a/pufferlib/config/ocean/tripletriad.ini b/pufferlib/config/ocean/tripletriad.ini index aae55d096..4d4a1ffd8 100644 --- a/pufferlib/config/ocean/tripletriad.ini +++ b/pufferlib/config/ocean/tripletriad.ini @@ -8,14 +8,14 @@ rnn_name = Recurrent num_envs = 1024 [vec] -num_envs = 8 +num_envs = 4 [train] total_timesteps = 100_000_000 [sweep.train.total_timesteps] distribution = log_normal -min = 5e7 +min = 1e7 max = 2e8 mean = 1e8 -scale = 0.25 +scale = time diff --git a/pufferlib/environments/cogames/environment.py b/pufferlib/environments/cogames/environment.py index 0fbe47595..61e36cd27 100644 --- a/pufferlib/environments/cogames/environment.py +++ b/pufferlib/environments/cogames/environment.py @@ -21,7 +21,7 @@ def make(name="cogames.cogs_v_clips.machina_1.open_world", variants=None, cogs=N simulator = Simulator() simulator.add_event_handler(StatsTracker(NoopStatsWriter())) env = PufferMettaGridEnv(simulator=simulator, cfg=env_cfg, buf=buf, seed=seed or 0) - env.render_mode = render + #env.render_mode = render if seed: env.reset(seed) return env diff --git a/pufferlib/extensions/breakout.c b/pufferlib/extensions/breakout.c new file mode 100644 index 000000000..560fa16a2 --- /dev/null +++ b/pufferlib/extensions/breakout.c @@ -0,0 +1,34 @@ +#include "../ocean/breakout/breakout.h" +#define OBS_SIZE 118 +#define ACT_SIZE 1 +#define OBS_TYPE FLOAT +#define ACT_TYPE FLOAT + +#define Env Breakout +#include "env_binding.h" + +void my_init(Env* env, Dict* kwargs) { + env->frameskip = dict_get(kwargs, "frameskip")->int_value; + env->width = dict_get(kwargs, "width")->int_value; + env->height = dict_get(kwargs, "height")->int_value; + env->initial_paddle_width = dict_get(kwargs, "paddle_width")->int_value; + env->paddle_height = dict_get(kwargs, "paddle_height")->int_value; + env->ball_width = dict_get(kwargs, "ball_width")->int_value; + env->ball_height = dict_get(kwargs, "ball_height")->int_value; + env->brick_width = dict_get(kwargs, "brick_width")->int_value; + env->brick_height = dict_get(kwargs, "brick_height")->int_value; + env->brick_rows = dict_get(kwargs, "brick_rows")->int_value; + env->brick_cols = dict_get(kwargs, "brick_cols")->int_value; + env->initial_ball_speed = dict_get(kwargs, "initial_ball_speed")->int_value; + env->max_ball_speed = dict_get(kwargs, "max_ball_speed")->int_value; + env->paddle_speed = dict_get(kwargs, "paddle_speed")->int_value; + env->continuous = dict_get(kwargs, "continuous")->int_value; + init(env); +} + +void my_log(Log* log, Dict* out) { + dict_set_float(out, "perf", log->perf); + dict_set_float(out, "score", log->score); + dict_set_float(out, "episode_return", log->episode_return); + dict_set_float(out, "episode_length", log->episode_length); +} diff --git 
a/pufferlib/extensions/cuda/kernels.cu b/pufferlib/extensions/cuda/kernels.cu new file mode 100644 index 000000000..8bfdfcf2a --- /dev/null +++ b/pufferlib/extensions/cuda/kernels.cu @@ -0,0 +1,1518 @@ +/* Kernels must launch on the current torch stream to be traced by cudagraphs. + * Launch functions take cudaStream_t as parameter - callers (modules.cu) should + * pass at::cuda::getCurrentCUDAStream() when using with torch. + */ + +#include +#include "ops.cuh" +#include +#include + +#include +#include + +#define SEQ_SIZE 32 +#define BLOCK_SIZE 256 +inline int grid_size(int N) { + return (N + BLOCK_SIZE - 1) / BLOCK_SIZE; +} +inline int seq_size(int N) { + return (N + SEQ_SIZE - 1) / SEQ_SIZE; +} + +// If you can get this to work, go ahead. I tried. +// NVCC won't parse templated types in kernel launches +/* +template