Update sampler.py

fix import
This commit is contained in:
Robin Rombach 2022-11-16 21:34:06 +01:00 committed by GitHub
parent 5a00c4f8db
commit 21f890f9da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
import torch
from .solver import NoiseScheduleVP, model_wrapper, DPM_Solver
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
class DPMSolverSampler(object):
@ -79,4 +79,4 @@ class DPMSolverSampler(object):
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
return x.to(device), None
return x.to(device), None