LWC 51:681. Next Closest Time

LWC 51:681. Next Closest Time

传送门:681. Next Closest Time

Problem:

Given a time represented in the format “HH:MM”, form the next closest time by reusing the current digits. There is no limit on how many times a digit can be reused.

You may assume the given input string is always valid. For example, “01:34”, “12:09” are all valid. “1:34”, “12:9” are all invalid.

Example 1:

Input: “19:34”
Output: “19:39”
Explanation: The next closest time choosing from digits 1, 9, 3, 4, is 19:39, which occurs 5 minutes later. It is not 19:33, because this occurs 23 hours and 59 minutes later.

Example 2:

Input: “23:59”
Output: “22:22”
Explanation: The next closest time choosing from digits 2, 3, 5, 9, is 22:22. It may be assumed that the returned time is next day’s time since it is smaller than the input time numerically.

思路:
暴力遍历:从当前时间开始,每次把时间向后推进一分钟(分钟到 60 时进位到下一小时,小时到 24 时回绕为 00),找到第一个四位数字全部来自原时间数字集合的时刻即可跳出。由于原时间本身一定满足条件,最多推进 1440 步(一整天)循环必然结束。

代码如下:

    /**
     * Returns the next closest time that reuses only the digits of {@code time}.
     * Brute force: step forward one minute at a time (59 minutes carry into the
     * next hour, hour 24 wraps to 0) until every digit of the candidate occurs
     * among the digits of {@code time}. Because {@code time} itself qualifies,
     * the loop ends after at most 1440 steps; {@code time} is returned only when
     * no other time can be formed.
     *
     * @param time a valid "HH:MM" string
     * @return the next closest time as "HH:MM"
     */
    public String nextClosestTime(String time) {
        String[] val = time.split(":");
        Set<Integer> set = new HashSet<>();
        int num1 = Integer.parseInt(val[0]);
        set.add(num1 / 10);
        set.add(num1 % 10);
        int num2 = Integer.parseInt(val[1]);
        set.add(num2 / 10);
        set.add(num2 % 10);

        int hour = num1;
        int minu = num2;

        // do-while: the clock is always advanced at least once, so the input
        // time is only returned after a full day; a single copy of the
        // minute-advance logic is enough.
        do {
            minu++;
            if (minu == 60) {
                hour++;
                minu = 0;
                if (hour == 24) hour = 0;
            }
        } while (!contains(hour, minu, set));

        // Zero-pad single-digit hours/minutes to keep the "HH:MM" format.
        String hh = (hour >= 0 && hour <= 9) ? "0" + hour : String.valueOf(hour);
        String mm = (minu >= 0 && minu <= 9) ? "0" + minu : String.valueOf(minu);
        return hh + ":" + mm;
    }

    /**
     * Tells whether all four digits of hour:minu are present in {@code set}.
     *
     * @param hour hour value whose tens and units digits are checked
     * @param minu minute value whose tens and units digits are checked
     * @param set  the allowed digits
     * @return true iff every digit is allowed
     */
    public boolean contains(int hour, int minu, Set<Integer> set) {
        int[] digits = {hour / 10, hour % 10, minu / 10, minu % 10};
        for (int d : digits) {
            if (!set.contains(d)) {
                return false;
            }
        }
        return true;
    }

优化一下:

    /**
     * Returns the next closest time that reuses only the digits of {@code time},
     * advancing minute by minute until every digit is allowed.
     *
     * @param time a valid "HH:MM" string
     * @return the next closest time as "HH:MM"
     */
    public String nextClosestTime(String time) {
        Set<Integer> digits = new HashSet<>();
        String[] parts = time.split(":");
        // {hour, minute}; add(...) also records each value's two digits.
        int[] clock = {add(digits, parts[0]), add(digits, parts[1])};
        do {
            nxt(clock);
        } while (!contains(clock[0], clock[1], digits));
        return valid(clock[0]) + ":" + valid(clock[1]);
    }

    /**
     * Advances {time[0] = hour, time[1] = minute} in place by one minute:
     * minute 60 carries into the hour, and hour 24 wraps back to 0.
     *
     * @param time two-element array {hour, minute}, modified in place
     */
    public void nxt(int[] time) {
        time[1]++;
        if (time[1] != 60) {
            return;
        }
        time[1] = 0;
        time[0]++;
        if (time[0] == 24) {
            time[0] = 0;
        }
    }

    /**
     * Parses {@code timeStr} as an integer, records its tens and units digits
     * in {@code set}, and returns the parsed value.
     *
     * @param set     digit set to extend
     * @param timeStr an hour or minute field such as "07"
     * @return the parsed integer value
     */
    public int add(Set<Integer> set, String timeStr) {
        int value = Integer.parseInt(timeStr);
        for (int digit : new int[] {value / 10, value % 10}) {
            set.add(digit);
        }
        return value;
    }

    /**
     * Formats an hour or minute value, prefixing "0" when it lies in 0..9;
     * any other value is rendered unchanged.
     *
     * @param time hour or minute value
     * @return the (possibly zero-padded) decimal text
     */
    public String valid(int time) {
        boolean singleDigit = time >= 0 && time <= 9;
        return singleDigit ? "0" + time : String.valueOf(time);
    }

    /**
     * Tells whether both digits of {@code hour} and both digits of
     * {@code minu} belong to {@code set}.
     *
     * @return true iff all four digits are allowed
     */
    public boolean contains(int hour, int minu, Set<Integer> set) {
        boolean hourDigitsOk = set.contains(hour / 10) && set.contains(hour % 10);
        return hourDigitsOk && set.contains(minu / 10) && set.contains(minu % 10);
    }

当然,你也可以直接搜索:原时间最多只有4个不同的数字,每一位都从中任选,所以总共有 4 * 4 * 4 * 4 = 256 种组合。过滤掉不合法的时间(小时大于 23 或分钟大于 59)后,在剩下的候选中找出相对当前时间向后经过分钟数(diff)最小的 hour 和 minute 即可;与当前时间相同的候选视为第二天的同一时刻。

代码如下:

    // Smallest diff(...) value seen so far in the current search; 0x3f3f3f3f acts as "infinity".
    int diff = 0x3f3f3f3f;
    // "HH:MM" text of the best candidate found so far.
    String result = "";
    // Hour and minute of the input time, read by diff(...).
    int h;
    int m;

    /**
     * Returns the next closest time reusing the digits of {@code time} by
     * enumerating all 4^4 digit combinations via dfs and keeping the one with
     * the smallest forward distance from the input.
     *
     * @param time a valid "HH:MM" string
     * @return the next closest time as "HH:MM"
     */
    public String nextClosestTime(String time) {
        String[] val = time.split(":");
        int hour = Integer.parseInt(val[0]);
        int minu = Integer.parseInt(val[1]);
        int[] digit = {hour / 10, hour % 10, minu / 10, minu % 10};

        h = hour;
        m = minu;
        // Reset the search state: without this, a second call on the same
        // object would compare against the previous call's best distance and
        // could return a stale result.
        diff = 0x3f3f3f3f;
        result = "";

        dfs(digit, 0, new int[4]);

        return result;
    }

    /**
     * Fills positions i..3 of {@code ans} with every choice (repetition
     * allowed) from {@code digit[0..3]}. Each complete combination that is a
     * valid clock time (hour 0..23, minute 0..59) is scored with diff(...),
     * and the best one is stored in the {@code diff}/{@code result} fields.
     */
    void dfs(int[] digit, int i, int[] ans) {
        if (i != 4) {
            for (int k = 0; k < 4; k++) {
                ans[i] = digit[k];
                dfs(digit, i + 1, ans);
                ans[i] = -1; // clear the slot before the next choice
            }
            return;
        }

        int hour = ans[0] * 10 + ans[1];
        int minu = ans[2] * 10 + ans[3];
        boolean isClockTime = hour >= 0 && hour <= 23 && minu >= 0 && minu <= 59;
        if (!isClockTime) {
            return;
        }
        int df = diff(hour, minu);
        if (df < diff) {
            diff = df;
            result = valid(hour) + ":" + valid(minu);
        }
    }

    /**
     * Returns how many minutes after the input time (fields h:m) the candidate
     * hour:minu next occurs, in 1..1440 for valid times. A candidate that is
     * not strictly later today, including the input time itself, is treated as
     * tomorrow's occurrence.
     *
     * The previous version wrapped by 60 * 60 (seconds per hour) instead of the
     * 24 * 60 minutes in a day. It still ordered candidates correctly, but it
     * did not return the real elapsed minutes.
     */
    int diff(int hour, int minu) {
        final int minutesPerDay = 24 * 60;
        int delta = (hour * 60 + minu) - (h * 60 + m);
        // Non-positive deltas belong to the next day: shift them forward one day.
        return delta > 0 ? delta : delta + minutesPerDay;
    }

    /**
     * Renders an hour or minute value, adding a leading "0" when it is in 0..9;
     * every other value is returned as its plain decimal text.
     */
    public String valid(int time) {
        String digits = String.valueOf(time);
        return (time < 0 || time > 9) ? digits : "0" + digits;
    }

剪枝优化:

    // Smallest diff(...) value seen so far in the current search; 0x3f3f3f3f acts as "infinity".
    int diff = 0x3f3f3f3f;
    // "HH:MM" text of the best candidate found so far.
    String result = "";
    // Hour and minute of the input time, read by diff(...).
    int h;
    int m;

    /**
     * Returns the next closest time reusing the digits of {@code time} by
     * searching digit combinations with dfs (which prunes invalid hours and
     * minutes) and keeping the one with the smallest forward distance.
     *
     * @param time a valid "HH:MM" string
     * @return the next closest time as "HH:MM"
     */
    public String nextClosestTime(String time) {
        String[] val = time.split(":");
        int hour = Integer.parseInt(val[0]);
        int minu = Integer.parseInt(val[1]);
        int[] digit = {hour / 10, hour % 10, minu / 10, minu % 10};

        h = hour;
        m = minu;
        // Reset the search state: without this, a second call on the same
        // object would compare against the previous call's best distance and
        // could return a stale result.
        diff = 0x3f3f3f3f;
        result = "";

        dfs(digit, 0, new int[4]);

        return result;
    }

    /**
     * Pruned search over digit combinations.
     *
     * Position {@code i} of {@code ans} is filled with each of {@code digit[0..3]}.
     * The search only descends past position 1 when the hour is in 0..23, and
     * past position 3 when the minute is in 0..59. When started at i == 0,
     * every complete combination is therefore a valid time; it is scored with
     * diff(...) and the best one is kept in the {@code diff}/{@code result} fields.
     */
    void dfs(int[] digit, int i, int[] ans) {
        if (i == 4) {
            int hour = 10 * ans[0] + ans[1];
            int minu = 10 * ans[2] + ans[3];
            int df = diff(hour, minu);
            if (df < diff) {
                diff = df;
                result = valid(hour) + ":" + valid(minu);
            }
            return;
        }
        for (int k = 0; k < 4; k++) {
            ans[i] = digit[k];
            boolean promising;
            switch (i) {
                case 1: {
                    int hh = 10 * ans[0] + ans[1];
                    promising = hh >= 0 && hh <= 23;
                    break;
                }
                case 3: {
                    int mm = 10 * ans[2] + ans[3];
                    promising = mm >= 0 && mm <= 59;
                    break;
                }
                default:
                    promising = true;
            }
            if (promising) {
                dfs(digit, i + 1, ans);
            }
        }
    }

    /**
     * Returns how many minutes after the input time (fields h:m) the candidate
     * hour:minu next occurs, in 1..1440 for valid times. A candidate that is
     * not strictly later today, including the input time itself, is treated as
     * tomorrow's occurrence.
     *
     * The previous version wrapped by 60 * 60 (seconds per hour) instead of the
     * 24 * 60 minutes in a day. It still ordered candidates correctly, but it
     * did not return the real elapsed minutes.
     */
    int diff(int hour, int minu) {
        final int minutesPerDay = 24 * 60;
        int delta = (hour * 60 + minu) - (h * 60 + m);
        // Non-positive deltas belong to the next day: shift them forward one day.
        return delta > 0 ? delta : delta + minutesPerDay;
    }

    /**
     * Zero-pads an hour or minute value in 0..9 to two characters; any other
     * value is returned as its plain decimal text.
     */
    public String valid(int time) {
        if (time < 0 || time > 9) {
            return Integer.toString(time);
        }
        return "0" + time;
    }
class UniformAffineQuantizer(nn.Module):
    """Uniform affine fake-quantizer with optional learnable weight clipping (LWC).

    On every ``forward`` the quantization parameters are recomputed from the
    input itself:

    * min/max are taken over the last dimension, after reshaping into groups
      of ``group_size`` columns when grouping is enabled;
    * with ``lwc`` enabled, they are multiplied by the learnable factors
      ``sigmoid(upbound_factor)`` / ``sigmoid(lowbound_factor)``;
    * they are turned into ``scale`` / ``round_zero_point``.

    The input is then quantized with ``round_ste``, clamped to
    ``[qmin, qmax]``, and immediately dequantized. The output has the same
    shape as the input. ``round_ste`` and ``CLIPMIN`` are defined elsewhere
    in the module.
    """

    def __init__(
        self,
        n_bits: int = 8,
        symmetric: bool = False,
        per_channel_axes=None,
        metric="minmax",
        dynamic=False,
        dynamic_method="per_cluster",
        group_size=None,
        shape=None,
        lwc=False,
        disable_zero_point=False,
    ):
        """Configure bit width, integer range and (optionally) clipping factors.

        Args:
            n_bits: Bit width; must satisfy ``2 <= n_bits <= 16`` (asserted).
                With ``n_bits >= 16`` ``forward`` returns its input unchanged.
            symmetric: Derive the scale from ``max(|xmin|, |xmax|)`` and use a
                fixed zero point of ``2**(n_bits-1) - 1``.
            per_channel_axes: Stored on the instance; not read by this class.
                ``None`` (the default) stores a fresh empty list.
            metric: ``"fix0to1"`` selects a rounding-only shortcut in
                ``forward``; no other value is inspected by this class.
            dynamic: Stored on the instance; not read by this class.
            dynamic_method: ``forward`` implements only ``"per_token"`` and
                ``"per_channel"`` (both reduce over the last dimension). Any
                other value, including the default ``"per_cluster"``, makes
                ``forward`` raise ``NotImplementedError``.
            group_size: If set, 2-D inputs are quantized in groups of
                ``group_size`` consecutive columns.
            shape: Shape of the tensor to be quantized. It is required when
                ``lwc`` is True, to size the clipping factors (presumably
                ``(out_features, in_features)`` of a linear weight; confirm).
            lwc: Enable learnable weight clipping.
            disable_zero_point: Use the signed range
                ``[-2**(n_bits-1), 2**(n_bits-1) - 1]`` and no zero point.
        """
        super().__init__()
        self.symmetric = symmetric
        self.disable_zero_point = disable_zero_point
        assert 2 <= n_bits <= 16, "bitwidth not supported"
        self._set_n_bits(n_bits)
        # A None default avoids sharing one mutable [] across all instances.
        self.per_channel_axes = [] if per_channel_axes is None else per_channel_axes
        self.metric = metric
        self.cluster_counts = None
        self.cluster_dim = None

        self.scale = None
        self.zero_point = None
        self.round_zero_point = None

        self.cached_xmin = None
        self.cached_xmax = None
        self.dynamic = dynamic
        self.dynamic_method = dynamic_method
        # Number of zero columns padded onto the last dim so it becomes a
        # multiple of group_size. Only computed when lwc is enabled.
        self.deficiency = 0
        self.lwc = lwc

        # Initial logit of the clipping factors: sigmoid(4.0) ~= 0.982, so
        # clipping starts out keeping ~98% of the observed min/max.
        init_value = 4.0
        if lwc:
            if group_size:
                # One clipping factor per (row, column-group) pair.
                dim1 = int(shape[0] * math.ceil(shape[1] / group_size))
                self.deficiency = shape[-1] % group_size
                if self.deficiency > 0:
                    self.deficiency = group_size - self.deficiency
                    assert self.symmetric  # support for mlc-llm symmetric quantization
            else:
                # One clipping factor per row.
                dim1 = shape[0]
            self.upbound_factor = nn.Parameter(torch.ones((dim1, 1)) * init_value)
            self.lowbound_factor = nn.Parameter(torch.ones((dim1, 1)) * init_value)
        self.sigmoid = nn.Sigmoid()

        self.enable = True
        self.group_size = group_size

    def _set_n_bits(self, n_bits):
        """Store ``n_bits`` and derive the integer clamp range ``[qmin, qmax]``."""
        self.n_bits = n_bits
        if self.disable_zero_point:
            # Signed range, e.g. [-128, 127] for 8 bits.
            self.qmin = -(2 ** (n_bits - 1))
            self.qmax = 2 ** (n_bits - 1) - 1
        else:
            # Unsigned range, e.g. [0, 255] for 8 bits.
            self.qmin = 0
            self.qmax = 2 ** n_bits - 1

    def change_n_bits(self, n_bits):
        """Switch to ``n_bits`` and recompute ``qmin``/``qmax`` (no range check)."""
        self._set_n_bits(n_bits)

    def _pad_to_group_multiple(self, x):
        """Append ``self.deficiency`` zero columns to 2-D ``x``; no-op when it is 0."""
        if self.deficiency > 0:
            pad_zeros = torch.zeros(
                (x.shape[0], self.deficiency), dtype=x.dtype, device=x.device
            )
            x = torch.cat((x, pad_zeros), dim=1)
        return x

    def fake_quant(self, x, scale, round_zero_point):
        """Quantize ``x`` with ``scale``/``round_zero_point`` and dequantize it back.

        Args:
            x: Tensor to fake-quantize; must be 2-D when ``group_size`` is set.
            scale: Scale broadcastable against ``x``, or against its
                ``(-1, group_size)`` reshape when grouping.
            round_zero_point: Rounded zero point with the same broadcasting,
                or ``None`` to quantize without an offset.

        Returns:
            The dequantized tensor, with the same shape as the input ``x``.
        """
        x = self._pad_to_group_multiple(x)
        if self.group_size:
            assert len(x.shape) == 2, "only support linear layer now"
            dim1, dim2 = x.shape
            x = x.reshape(-1, self.group_size)
        x_int = round_ste(x / scale)
        if round_zero_point is not None:
            x_int = x_int.add(round_zero_point)
        x_int = x_int.clamp(self.qmin, self.qmax)
        x_dequant = x_int
        if round_zero_point is not None:
            x_dequant = x_dequant.sub(round_zero_point)
        x_dequant = x_dequant.mul(scale)
        if self.group_size:
            x_dequant = x_dequant.reshape(dim1, dim2)
        if self.deficiency > 0:
            # Drop the zero columns added by the padding above.
            x_dequant = x_dequant[:, : -self.deficiency]
        return x_dequant

    def forward(self, x: torch.Tensor):
        """Fake-quantize ``x``.

        Returns ``x`` itself when disabled or when ``n_bits >= 16``.

        Raises:
            NotImplementedError: if ``dynamic_method`` is neither
                ``"per_token"`` nor ``"per_channel"``, and no early-return
                path applies.
        """
        if self.n_bits >= 16 or not self.enable:
            return x
        if self.metric == "fix0to1":
            # Snap onto a grid with step 1 / (2**n_bits - 1), presumably for
            # inputs already in [0, 1]. Out-of-place, so the caller's tensor
            # is not modified (the previous in-place mul_/round_/div_ was).
            levels = 2 ** self.n_bits - 1
            return x.mul(levels).round().div(levels)
        if self.dynamic_method in ("per_token", "per_channel"):
            self.per_token_dynamic_calibration(x)
        else:
            raise NotImplementedError()
        return self.fake_quant(x, self.scale, self.round_zero_point)

    def per_token_dynamic_calibration(self, x):
        """Compute ``self.scale`` and ``self.round_zero_point`` from ``x``.

        Statistics are taken over the last dimension of ``x``. When grouping
        is enabled, ``x`` is first padded and reshaped to
        ``(-1, group_size)``. With ``lwc`` enabled, the max/min are multiplied
        by ``sigmoid(upbound_factor)`` / ``sigmoid(lowbound_factor)``.
        ``round_zero_point`` is ``None`` when ``disable_zero_point`` is set.
        """
        if self.group_size:
            x = self._pad_to_group_multiple(x).reshape(-1, self.group_size)
        reduce_shape = [-1]
        xmin = x.amin(reduce_shape, keepdim=True)
        xmax = x.amax(reduce_shape, keepdim=True)
        if self.lwc:
            xmax = self.sigmoid(self.upbound_factor) * xmax
            xmin = self.sigmoid(self.lowbound_factor) * xmin
        if self.symmetric:
            abs_max = torch.max(xmax.abs(), xmin.abs())
            scale = abs_max / (2 ** (self.n_bits - 1) - 1)
            self.scale = scale.clamp(min=CLIPMIN, max=1e4)
            # Fixed zero point of 2**(n_bits-1) - 1 in symmetric mode.
            zero_point = (2 ** (self.n_bits - 1) - 1) * torch.ones_like(self.scale)
        else:
            # Named x_range so it does not shadow the builtin `range`.
            x_range = xmax - xmin
            scale = x_range / (2 ** self.n_bits - 1)
            self.scale = scale.clamp(min=CLIPMIN, max=1e4)
            zero_point = -xmin / self.scale
        if self.disable_zero_point:
            self.round_zero_point = None
        else:
            self.round_zero_point = zero_point.clamp(min=-1e4, max=1e4).round()

    def register_scales_and_zeros(self):
        """Move ``scale``/``round_zero_point`` into buffers ``scales``/``zeros``.

        The plain ``scale``/``round_zero_point`` attributes are deleted
        afterwards, so calling this again before they are recomputed raises
        ``AttributeError``.
        """
        self.register_buffer("scales", self.scale)
        self.register_buffer("zeros", self.round_zero_point)
        del self.scale
        del self.round_zero_point
最新发布
07-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值