import%20marimo%0A%0A__generated_with%20%3D%20%220.19.9%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Distributed%20Tensors%20in%20Monarch%0A%0A%20%20%20%20Monarch%20can%20broadcast%20tensor%20compute%20to%20a%20mesh%20of%20processes%2C%20allowing%20a%20single%0A%20%20%20%20controller%20to%20do%20distributed%20tensor%20compute.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20monarch%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%20%20%20%20from%20monarch.actor%20import%20this_host%0A%0A%20%20%20%20torch.set_default_device(%22cuda%22)%0A%20%20%20%20return%20monarch%2C%20nn%2C%20this_host%2C%20torch%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Meshes%0A%0A%20%20%20%20All%20computation%20is%20done%20on%20a%20'mesh'%20of%20devices.%0A%20%20%20%20Here%20we%20create%20a%20mesh%20composed%20of%20the%20machine%20running%20the%20notebook%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(this_host)%3A%0A%20%20%20%20mesh%20%3D%20this_host().spawn_procs(%7B%22gpu%22%3A%208%7D)%0A%20%20%20%20print(mesh.to_table())%0A%20%20%20%20return%20(mesh%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Without%20a%20mesh%20active%2C%20torch%20runs%20locally.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(torch)%3A%0A%20%20%20%20torch.rand(3%2C%204)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Once%20active%2C%20torch%20runs%20on%20every%20device%20in%20the%20mesh.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mesh%2C%20torch)%3A%0A%20%20%20%20with%20mesh.activate()%3A%0A%20%20%20%20%20%20%20%20t%20%3D%20torch.rand(3%2C%204%2C%20device%3D%22cuda%22)%0A%20%20%20%20t%0A%20%20%20%20return%20(t%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%60inspect%60%20moves%20rank%200's%20copy%20of%20%60t%60%20to%20the%20notebook%20for%20debugging.%0A%20%20%20%20Providing%20coordinates%20lets%20us%20inspect%20other%20ranks'%20copies.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(monarch%2C%20t)%3A%0A%20%20%20%20monarch.inspect(t)%0A%20%20%20%20monarch.show(t)%0A%20%20%20%20monarch.show(t%2C%20gpu%3D1)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Tensor%20Commands%0A%0A%20%20%20%20Any%20command%20done%20on%20the%20controller%2C%20such%20as%20multiplying%20these%20tensors%2C%0A%20%20%20%20performs%20that%20action%20to%20all%20of%20the%20tensors%20in%20the%20collection.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mesh%2C%20monarch%2C%20t)%3A%0A%20%20%20%20with%20mesh.activate()%3A%0A%20%20%20%20%20%20%20%20obj%20%3D%20t%20%40%20t.T%0A%20%20%20%20%20%20%20%20monarch.show(obj)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20If%20a%20command%20fails%2C%20the%20workers%20stay%20alive%20and%20can%20execute%20future%20commands%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mesh%2C%20monarch%2C%20t%2C%20torch)%3A%0A%20%20%20%20try%3A%0A%20%20%20%20%20%20%20%20with%20mesh.activate()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20big_w%20%3D%20torch.rand(4%2C%201024%20*%201024%20*%201024%20*%201024%20*%208%2C%20device%3D%22cuda%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20v%20%3D%20t%20%40%20big_w%0A%20%20%20%20%20%20%20%20%20%20%20%20monarch.show(v)%0A%20%20%20%20except%20Exception%3A%0A%20%20%20%20%20%20%20%20import%20traceback%0A%20%20%20%20%20%20%20%20traceback.print_exc()%0A%0A%20%20%20%20print(%22RECOVERED!%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Since%20monarch%20recovers%20from%20errors%2C%20you%20can%20search%20for%20what%20works%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mesh%2C%20monarch%2C%20torch)%3A%0A%20%20%20%20N%20%3D%201%0A%20%20%20%20while%20True%3A%0A%20%20%20%20%20%20%20%20try%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20mesh.activate()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20batch%20%3D%20torch.rand(N%2C%201024%20*%201024%20*%201024%2C%20device%3D%22cuda%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20monarch.inspect(batch.sum())%0A%20%20%20%20%20%20%20%20%20%20%20%20N%20%3D%202%20*%20N%0A%20%20%20%20%20%20%20%20%20%20%20%20print(f%22at%20least%202**%7BN%7D%20elements%20work%22)%0A%20%20%20%20%20%20%20%20except%20Exception%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20print(f%22max%20is%202**%7BN%7D%20elements%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Collectives%0A%0A%20%20%20%20Each%20machine%20has%20its%20own%20copy%20of%20the%20tensor%2C%20similar%20to%20%60torch.distributed%60.%0A%0A%20%20%20%20To%20compute%20across%20tensors%20in%20the%20mesh%2C%20we%20use%20special%20communication%20operators%2C%0A%20%20%20%20analogous%20to%20collectives.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mesh%2C%20monarch%2C%20torch)%3A%0A%20%20%20%20with%20mesh.activate()%3A%0A%20%20%20%20%20%20%20%20a%20%3D%20torch.rand(3%2C%204%2C%20device%3D%22cuda%22)%0A%20%20%20%20%20%20%20%20r%20%3D%20a.reduce(%22gpu%22%2C%20%22sum%22)%0A%0A%20%20%20%20monarch.show(a%2C%20gpu%3D0)%0A%20%20%20%20monarch.show(a%2C%20gpu%3D1)%0A%0A%20%20%20%20monarch.show(r%2C%20gpu%3D0)%0A%20%20%20%20monarch.show(r%2C%20gpu%3D1)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Remote%20GPUs%0A%0A%20%20%20%20We%20can%20also%20connect%20to%20remote%20GPUs%20reserved%20from%20some%20scheduler.%0A%20%20%20%20Here%20we%20simulate%20a%20multi-host%20setup%20locally%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(this_host%2C%20torch)%3A%0A%20%20%20%20remote_mesh%20%3D%20this_host().spawn_procs(%7B%22host%22%3A%204%2C%20%22gpu%22%3A%204%7D)%0A%0A%20%20%20%20print(remote_mesh.to_table())%0A%20%20%20%20with%20remote_mesh.activate()%3A%0A%20%20%20%20%20%20%20%20eg%20%3D%20torch.rand(3%2C%204%2C%20device%3D%22cuda%22)%0A%20%20%20%20%20%20%20%20rgpu%20%3D%20eg.reduce(%22gpu%22%2C%20%22sum%22)%0A%20%20%20%20%20%20%20%20rhost%20%3D%20eg.reduce(%22host%22%2C%20%22sum%22)%0A%20%20%20%20return%20(remote_mesh%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Device%20Mesh%20Dimensions%0A%0A%20%20%20%20Meshes%20can%20be%20renamed%20and%20reshaped%20to%20fit%20the%20parallelism%20desired.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(remote_mesh)%3A%0A%20%20%20%20mesh_2d_parallel%20%3D%20remote_mesh.rename(host%3D%22dp%22%2C%20gpu%3D%22tp%22)%0A%20%20%20%20print(mesh_2d_parallel.to_table())%0A%0A%20%20%20%20mesh_3d_parallel%20%3D%20remote_mesh.split(host%3D(%22dp%22%2C%20%22pp%22)%2C%20gpu%3D(%22tp%22%2C)%2C%20pp%3D2)%0A%20%20%20%20print(mesh_3d_parallel.to_table())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Pipelining%0A%0A%20%20%20%20Pipelining%20is%20accomplished%20by%20slicing%20the%20mesh%2C%20and%20copying%20tensors%20from%0A%20%20%20%20one%20mesh%20to%20another.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(remote_mesh)%3A%0A%20%20%20%20pipeline_mesh%20%3D%20remote_mesh.rename(host%3D%22pp%22)%0A%20%20%20%20meshes%20%3D%20%5Bpipeline_mesh.slice(pp%3Di)%20for%20i%20in%20range(pipeline_mesh.size(%22pp%22))%5D%0A%20%20%20%20print(meshes%5B0%5D.to_table())%0A%20%20%20%20return%20(meshes%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Initialize%20a%20model%20across%20multiple%20pipeline%20stages%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(meshes%2C%20monarch%2C%20nn%2C%20torch)%3A%0A%20%20%20%20layers_per_stage%20%3D%202%0A%20%20%20%20stages%20%3D%20%5B%5D%0A%20%20%20%20for%20stage_mesh%20in%20meshes%3A%0A%20%20%20%20%20%20%20%20with%20stage_mesh.activate()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20layers%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20_%20in%20range(layers_per_stage)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20layers.extend(%5Bnn.Linear(4%2C%204)%2C%20nn.ReLU()%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20stages.append(nn.Sequential(*layers))%0A%0A%20%20%20%20def%20forward_pipeline(x)%3A%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20stage_mesh%2C%20stage%20in%20zip(meshes%2C%20stages)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20x.to_mesh(stage_mesh)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20with%20stage_mesh.activate()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20x%20%3D%20stage(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20x%0A%0A%20%20%20%20with%20meshes%5B0%5D.activate()%3A%0A%20%20%20%20%20%20%20%20input%20%3D%20torch.rand(3%2C%204%2C%20device%3D%22cuda%22)%0A%0A%20%20%20%20output%20%3D%20forward_pipeline(input)%0A%20%20%20%20monarch.show(output)%0A%20%20%20%20print(output.mesh.to_table())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20DDP%20Example%0A%0A%20%20%20%20The%20next%20sections%20use%20an%20example%20of%20writing%20DDP%20to%20illustrate%20a%20typical%20way%20to%0A%20%20%20%20develop%20code%20in%20Monarch.%0A%0A%20%20%20%20We%20interleave%20the%20backward%20pass%20with%20the%20gradient%20reductions%20and%20parameter%20updates.%0A%20%20%20%20%60monarch.grad_generator%60%20incrementally%20runs%20the%20backward%20pass%2C%20returning%20an%20iterator%0A%20%20%20%20that%20computes%20the%20grad%20parameters%20one%20at%20a%20time.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(monarch%2C%20torch)%3A%0A%20%20%20%20def%20train(model%2C%20input%2C%20target)%3A%0A%20%20%20%20%20%20%20%20loss%20%3D%20model(input%2C%20target)%0A%20%20%20%20%20%20%20%20rparameters%20%3D%20list(reversed(list(model.parameters())))%0A%20%20%20%20%20%20%20%20grads%20%3D%20monarch.grad_generator(loss%2C%20rparameters)%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20it%20%3D%20iter(zip(rparameters%2C%20grads))%0A%20%20%20%20%20%20%20%20%20%20%20%20todo%20%3D%20next(it%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20while%20todo%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20param%2C%20grad%20%3D%20todo%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20grad.reduce_(%22dp%22%2C%20%22sum%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20todo%20%3D%20next(it%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20param%20%2B%3D%200.01%20*%20grad%0A%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Simulation%0A%0A%20%20%20%20We%20can%20use%20a%20simulator%20to%20check%20for%20expected%20behavior%20of%20code%20before%20running%20it%0A%20%20%20%20for%20real.%20It%20is%20another%20kind%20of%20mesh%2C%20which%20simulates%20rather%20than%20computes%20results.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%20Overlapping%20Comms%2FCompute%0A%0A%20%20%20%20Commands%20on%20different%20devices%20run%20in%20parallel%2C%20but%20by%20default%20commands%20on%20a%20single%0A%20%20%20%20device%20run%20sequentially.%0A%0A%20%20%20%20We%20introduce%20parallelism%20on%20a%20device%20via%20stream%20objects.%20To%20use%20a%20tensor%20from%20one%0A%20%20%20%20stream%20on%20another%20we%20borrow%20it.%20The%20borrow%20API%20ensures%20deterministic%20memory%20usage%0A%20%20%20%20and%20eliminates%20the%20race%20conditions%20in%20the%20%60torch.cuda.stream%60%20API.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(monarch)%3A%0A%20%20%20%20main%20%3D%20monarch.get_active_stream()%0A%20%20%20%20comms%20%3D%20monarch.Stream(%22comms%22)%0A%20%20%20%20return%20(comms%2C)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20DDP%20example%20again%2C%20but%20using%20multiple%20streams%3A%0A%0A%20%20%20%20---%0A%0A%20%20%20%20**Previous%3A**%20%5BNB03%20%E2%80%94%20Fault%20Tolerance%5D(.%2F03_fault_tolerance.html)%20%C2%B7%20**Next%3A**%20%5BNB05%20%E2%80%94%20RL%20Intro%5D(.%2F05_rl_intro.html)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(comms%2C%20monarch%2C%20nn%2C%20torch)%3A%0A%20%20%20%20class%20Net(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20layers%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20x%20in%20range(8)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20layers.append(nn.Linear(4%2C%204))%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20layers.append(nn.ReLU())%0A%20%20%20%20%20%20%20%20%20%20%20%20self.layers%20%3D%20nn.Sequential(*layers)%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20input%2C%20target)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20output%20%3D%20self.layers(input)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20torch.nn.functional.cross_entropy(output%2C%20target)%0A%0A%20%20%20%20def%20train2(model%2C%20input%2C%20target)%3A%0A%20%20%20%20%20%20%20%20loss%20%3D%20model(input%2C%20target)%0A%20%20%20%20%20%20%20%20rparameters%20%3D%20list(reversed(list(model.parameters())))%0A%20%20%20%20%20%20%20%20grads%20%3D%20monarch.grad_generator(loss%2C%20rparameters)%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20NEW%3A%20iter%20also%20produces%20the%20tensor%20borrowed%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20to%20the%20comm%20stream%0A%20%20%20%20%20%20%20%20%20%20%20%20it%20%3D%20iter(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20(param%2C%20grad%2C%20*comms.borrow(grad%2C%20mutable%3DTrue))%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20param%2C%20grad%20in%20zip(rparameters%2C%20grads)%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20todo%20%3D%20next(it%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20while%20todo%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20param%2C%20grad%2C%20comm_grad%2C%20borrow%20%3D%20todo%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20NEW%3A%20compute%20the%20reduce%20on%20the%20comm%20stream%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20with%20comms.activate()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20comm_grad.reduce_(%22dp%22%2C%20%22sum%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20borrow.drop()%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20todo%20%3D%20next(it%2C%20None)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20param%20%2B%3D%200.01%20*%20grad%0A%0A%0A%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
0598da5e3f5196b94106f5ac9f6508ce