# Checkpointing

When training a PyTorch model with Accelerate, you may often want to save and continue a state of training. Doing so requires saving and loading the model, optimizer, RNG generators, and the `GradScaler`. Inside Accelerate are two convenience functions to achieve this quickly:
- Use `save_state()` for saving everything mentioned above to a folder location
- Use `load_state()` for loading everything stored from an earlier `save_state()`
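If all you need is the basic save/restore cycle with an explicit directory, the minimal sketch below illustrates it. The `output_dir` keyword and passing a directory string to `load_state()` are assumptions about the current signatures, so verify them against your installed Accelerate version:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Toy model and optimizer so the sketch is self-contained
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

# Save model, optimizer, and RNG states to an explicit folder
# (`output_dir` is an assumed keyword name)
accelerator.save_state(output_dir="checkpoints/step_0")

# ...later, restore everything from that same folder
accelerator.load_state("checkpoints/step_0")
```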
To further customize where and how states are saved through `save_state()`, the `ProjectConfiguration` class can be used. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint will be located at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.

It should be noted that the expectation is that those states come from the same training script; they should not be from two separate scripts.
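As a rough illustration of that configuration, the sketch below enables automatic checkpoint naming. The import path `accelerate.utils` and the `total_limit` argument are assumptions about the library's API rather than details stated above, so double-check them for your version:

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# Checkpoints land under my/save/path/checkpoints/checkpoint_{checkpoint_number}
config = ProjectConfiguration(
    project_dir="my/save/path",
    automatic_checkpoint_naming=True,
    total_limit=3,  # assumed argument for keeping only the most recent checkpoints
)
accelerator = Accelerator(project_config=config)

# Each call now writes to the next numbered checkpoint folder
accelerator.save_state()
```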
- By using `register_for_checkpointing()`, you can register custom objects to be automatically stored or loaded by the two functions above, as long as the object has `state_dict` and `load_state_dict` functionality. This could include objects such as a learning rate scheduler; a sketch with a custom object follows this list.
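To make the `state_dict`/`load_state_dict` requirement concrete, here is a small, hypothetical custom object being registered; the `StepCounter` class is purely illustrative and not part of Accelerate:

```python
from accelerate import Accelerator


class StepCounter:
    """Toy stateful object: anything exposing state_dict/load_state_dict can be registered."""

    def __init__(self):
        self.completed_steps = 0

    def state_dict(self):
        return {"completed_steps": self.completed_steps}

    def load_state_dict(self, state):
        self.completed_steps = state["completed_steps"]


accelerator = Accelerator(project_dir="my/save/path")
counter = StepCounter()

# Now saved by save_state() and restored by load_state() along with everything else
accelerator.register_for_checkpointing(counter)
```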
Below is a brief example using checkpointing to save and reload a state during training:

```python
from accelerate import Accelerator
import torch

accelerator = Accelerator(project_dir="my/save/path")

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state()

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore the previous state
accelerator.load_state("my/save/path/checkpointing/checkpoint_0")
```

## Restoring the state of the DataLoader

After resuming from a checkpoint, it may also be desirable to resume from a particular point in the active `DataLoader` if the state was saved during the middle of an epoch. You can use `skip_first_batches()` to do so.
te({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQSUwQWFjY2VsZXJhdG9yJTIwJTNEJTIwQWNjZWxlcmF0b3IocHJvamVjdF9kaXIlM0QlMjJteSUyRnNhdmUlMkZwYXRoJTIyKSUwQSUwQXRyYWluX2RhdGFsb2FkZXIlMjAlM0QlMjBhY2NlbGVyYXRvci5wcmVwYXJlKHRyYWluX2RhdGFsb2FkZXIpJTBBYWNjZWxlcmF0b3IubG9hZF9zdGF0ZSglMjJteV9zdGF0ZSUyMiklMEElMEElMjMlMjBBc3N1bWUlMjB0aGUlMjBjaGVja3BvaW50JTIwd2FzJTIwc2F2ZWQlMjAxMDAlMjBzdGVwcyUyMGludG8lMjB0aGUlMjBlcG9jaCUwQXNraXBwZWRfZGF0YWxvYWRlciUyMCUzRCUyMGFjY2VsZXJhdG9yLnNraXBfZmlyc3RfYmF0Y2hlcyh0cmFpbl9kYXRhbG9hZGVyJTJDJTIwMTAwKSUwQSUwQSUyMyUyMEFmdGVyJTIwdGhlJTIwZmlyc3QlMjBpdGVyYXRpb24lMkMlMjBnbyUyMGJhY2slMjB0byUyMCU2MHRyYWluX2RhdGFsb2FkZXIlNjAlMEElMEElMjMlMjBGaXJzdCUyMGVwb2NoJTBBZm9yJTIwYmF0Y2glMjBpbiUyMHNraXBwZWRfZGF0YWxvYWRlciUzQSUwQSUyMCUyMCUyMCUyMCUyMyUyMERvJTIwc29tZXRoaW5nJTBBJTIwJTIwJTIwJTIwcGFzcyUwQSUwQSUyMyUyMFNlY29uZCUyMGVwb2NoJTBBZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjAlMjMlMjBEbyUyMHNvbWV0aGluZyUwQSUyMCUyMCUyMCUyMHBhc3M=",highlighted:`from accelerate import Accelerator accelerator = Accelerator(project_dir="my/save/path") train_dataloader = accelerator.prepare(train_dataloader) accelerator.load_state("my_state") # Assume the checkpoint was saved 100 steps into the epoch skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100) # After the first iteration, go back to \`train_dataloader\` # First epoch for batch in skipped_dataloader: # Do something pass # Second epoch for batch in train_dataloader: # Do something pass`,wrap:!1}}),U=new pe({props:{source:"https://github.com/huggingface/accelerate/blob/main/docs/source/usage_guides/checkpoint.md"}}),{c(){r=n("meta"),Z=s(),J=n("p"),X=s(),B(i.$$.fragment),Y=s(),p=n("p"),p.textContent=Q,R=s(),M=n("ul"),M.innerHTML=x,I=s(),d=n("p"),d.innerHTML=L,C=s(),h=n("p"),h.textContent=K,A=s(),m=n("ul"),m.innerHTML=P,W=s(),y=n("p"),y.textContent=q,F=s(),B(j.$$.fragment),V=s(),B(u.$$.fragment),E=s(),w=n("p"),w.innerHTML=D,N=s(),B(f.$$.fragment),S=s(),B(U.$$.fragment),$=s(),_=n("p"),this.h()},l(e){const 
t=re("svelte-u9bgzb",document.head);r=o(t,"META",{name:!0,content:!0}),t.forEach(a),Z=c(e),J=o(e,"P",{}),O(J).forEach(a),X=c(e),T(i.$$.fragment,e),Y=c(e),p=o(e,"P",{"data-svelte-h":!0}),b(p)!=="svelte-13ap35m"&&(p.textContent=Q),R=c(e),M=o(e,"UL",{"data-svelte-h":!0}),b(M)!=="svelte-ziq9wt"&&(M.innerHTML=x),I=c(e),d=o(e,"P",{"data-svelte-h":!0}),b(d)!=="svelte-1u97blu"&&(d.innerHTML=L),C=c(e),h=o(e,"P",{"data-svelte-h":!0}),b(h)!=="svelte-iddkef"&&(h.textContent=K),A=c(e),m=o(e,"UL",{"data-svelte-h":!0}),b(m)!=="svelte-ubd99v"&&(m.innerHTML=P),W=c(e),y=o(e,"P",{"data-svelte-h":!0}),b(y)!=="svelte-3hrue1"&&(y.textContent=q),F=c(e),T(j.$$.fragment,e),V=c(e),T(u.$$.fragment,e),E=c(e),w=o(e,"P",{"data-svelte-h":!0}),b(w)!=="svelte-a4r6kf"&&(w.innerHTML=D),N=c(e),T(f.$$.fragment,e),S=c(e),T(U.$$.fragment,e),$=c(e),_=o(e,"P",{}),O(_).forEach(a),this.h()},h(){ee(r,"name","hf:doc:metadata"),ee(r,"content",de)},m(e,t){ie(document.head,r),l(e,Z,t),l(e,J,t),l(e,X,t),g(i,e,t),l(e,Y,t),l(e,p,t),l(e,R,t),l(e,M,t),l(e,I,t),l(e,d,t),l(e,C,t),l(e,h,t),l(e,A,t),l(e,m,t),l(e,W,t),l(e,y,t),l(e,F,t),g(j,e,t),l(e,V,t),g(u,e,t),l(e,E,t),l(e,w,t),l(e,N,t),g(f,e,t),l(e,S,t),g(U,e,t),l(e,$,t),l(e,_,t),z=!0},p:se,i(e){z||(v(i.$$.fragment,e),v(j.$$.fragment,e),v(u.$$.fragment,e),v(f.$$.fragment,e),v(U.$$.fragment,e),z=!0)},o(e){k(i.$$.fragment,e),k(j.$$.fragment,e),k(u.$$.fragment,e),k(f.$$.fragment,e),k(U.$$.fragment,e),z=!1},d(e){e&&(a(Z),a(J),a(X),a(Y),a(p),a(R),a(M),a(I),a(d),a(C),a(h),a(A),a(m),a(W),a(y),a(F),a(V),a(E),a(w),a(N),a(S),a($),a(_)),a(r),G(i,e),G(j,e),G(u,e),G(f,e),G(U,e)}}}const de='{"title":"Checkpointing","local":"checkpointing","sections":[{"title":"Restoring the state of the DataLoader","local":"restoring-the-state-of-the-dataloader","sections":[],"depth":2}],"depth":1}';function he(H){return ce(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}class we extends ne{constructor(r){super(),oe(this,r,he,Me,le,{})}}export{we as component};