题目## 题目
解题思路
给定 n n n 个整数和一个整数 m m m,要求计算所有可能的两两异或结果中大于 m m m 的数量。可以使用字典树(Trie)来高效地存储和查询异或结果。
- 字典树构建:将每个数字的二进制表示插入字典树中。
- 查询异或结果:对于每个数字,查找已有数字中与其异或大于 m m m 的数量。
- 统计结果:在查找过程中,统计符合条件的异或对的数量。
关键点
- 使用字典树存储二进制位,便于快速查找。
- 通过位运算判断异或结果是否大于 m m m。
- 采用贪心策略,优先选择能使异或结果更大的路径。
代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
const int N = 1e5 + 1;
int idx = 0;
int son[31 * N][2]; // 字典树节点
int num[31 * N]; // 节点计数
int find(int x, int m) {
int p = 0;
int ans = 0;
for (int i = 25; i >= 0; i--) {
int bit = (x >> i) & 1;
int mbit = (m >> i) & 1;
if (mbit == 0) {
// m的当前位是0,可以取更大的值
if (son[p][1 - bit]) {
ans += num[son[p][1 - bit]];
}
if (son[p][bit]) p = son[p][bit];
else break;
} else {
// m的当前位是1,必须取更大的值
if (son[p][1 - bit]) p = son[p][1 - bit];
else break;
}
}
return ans;
}
void insert(int x) {
int p = 0;
for (int i = 25; i >= 0; i--) {
int u = (x >> i & 1);
if (!son[p][u]) {
son[p][u] = ++idx;
}
p = son[p][u];
num[p]++;
}
}
void solve() {
int n, m;
cin >> n >> m;
i64 ans = 0;
for (int i = 0; i < n; i++) {
int x;
cin >> x;
ans += find(x, m); // 先查找
insert(x); // 再插入
}
cout << ans;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
solve();
return 0;
}
import java.util.*;
public class Main {
static final int N = 100001;
static int idx = 0;
static int[][] son = new int[31 * N][2];
static int[] num = new int[31 * N];
static int find(int x, int m) {
int p = 0;
int ans = 0;
for (int i = 25; i >= 0; i--) {
int bit = (x >> i) & 1;
int mbit = (m >> i) & 1;
if (mbit == 0) {
// m的当前位是0,可以取更大的值
if (son[p][1 - bit] != 0) {
ans += num[son[p][1 - bit]];
}
if (son[p][bit] != 0) p = son[p][bit];
else break;
} else {
// m的当前位是1,必须取更大的值
if (son[p][1 - bit] != 0) p = son[p][1 - bit];
else break;
}
}
return ans;
}
static void insert(int x) {
int p = 0;
for (int i = 25; i >= 0; i--) {
int u = (x >> i & 1);
if (son[p][u] == 0) {
son[p][u] = ++idx;
}
p = son[p][u];
num[p]++;
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
long ans = 0;
for (int i = 0; i < n; i++) {
int x = sc.nextInt();
ans += find(x, m);
insert(x);
}
System.out.println(ans);
sc.close();
}
}
class TrieNode:
def __init__(self):
self.children = [None, None] # 0和1的子节点
self.count = 0 # 该节点的计数
class Trie:
def __init__(self):
self.root = TrieNode()
self.idx = 0 # 节点索引
def insert(self, x):
p = self.root
for i in range(25, -1, -1): # 从高位到低位
bit = (x >> i) & 1
if p.children[bit] is None:
p.children[bit] = TrieNode()
p = p.children[bit]
p.count += 1 # 更新节点计数
def find(self, x, m):
p = self.root
ans = 0
for i in range(25, -1, -1): # 从高位到低位
bit = (x >> i) & 1
mbit = (m >> i) & 1
if mbit == 0:
# m的当前位是0,可以取更大的值
if p.children[1 - bit] is not None:
ans += p.children[1 - bit].count
if p.children[bit] is not None:
p = p.children[bit]
else:
break
else:
# m的当前位是1,必须取更大的值
if p.children[1 - bit] is not None:
p = p.children[1 - bit]
else:
break
return ans
def main():
import sys
input = sys.stdin.read
data = input().split()
n = int(data[0])
m = int(data[1])
trie = Trie()
ans = 0
for i in range(n):
x = int(data[i + 2]) # 从第三个元素开始读取数字
ans += trie.find(x, m) # 先查找
trie.insert(x) # 再插入
print(ans)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:字典树(Trie)+ 位运算
- 时间复杂度: O ( n ∗ 26 ) \mathcal{O}(n * 26) O(n∗26),其中 n n n 是数组长度
- 空间复杂度: O ( 31 ∗ n ) \mathcal{O}(31 * n) O(31∗n),用于存储字典树节点
解题思路
这是一道最优选址问题,主要思路如下:
- 首先计算所有房子的重心坐标 ( x , y ) (x, y) (x,y),作为初始搜索点
- 在重心附近的空地中寻找最优位置:
- 计算每个空地到所有房子的曼哈顿距离之和
- 选择距离和最小的位置作为中转站位置
- 优化策略:
- 只需要在重心附近的空地搜索,不需要遍历整个网格
- 使用曼哈顿距离( ∣ x 1 − x 2 ∣ + ∣ y 1 − y 2 ∣ |x_1-x_2| + |y_1-y_2| ∣x1−x2∣+∣y1−y2∣)计算距离
代码
#include <iostream>
#include <vector>
using namespace std;
int calDistance(vector<vector<int>>& grid, int n, int x, int y) {
int minDist = n * 2;
int bestX = 0, bestY = 0;
// 找到距离重心最近的空地
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
if(grid[i][j] == 0) {
int dist = abs(x - i) + abs(y - j);
if(dist < minDist) {
minDist = dist;
bestX = i;
bestY = j;
}
}
}
}
// 计算到所有房子的距离和
int sumDist = 0;
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
if(grid[i][j] == 1) {
sumDist += abs(bestX - i) + abs(bestY - j);
}
}
}
return sumDist;
}
int main() {
int n;
cin >> n;
vector<vector<int>> grid(n, vector<int>(n));
int count = 0, sumX = 0, sumY = 0;
// 读入网格并计算重心
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
cin >> grid[i][j];
if(grid[i][j] == 1) {
sumX += i;
sumY += j;
count++;
}
}
}
// 如果全是房子,无法建站
if(count == n * n) {
cout << -1 << endl;
return 0;
}
// 计算重心坐标
int centerX = sumX / count;
int centerY = sumY / count;
// 计算最小距离和
int result = calDistance(grid, n, centerX, centerY);
result = min(result, calDistance(grid, n, centerX + 1, centerY + 1));
cout << result << endl;
return 0;
}
import java.util.*;
public class Main {
static int calDistance(int[][] grid, int n, int x, int y) {
int minDist = n * 2;
int bestX = 0, bestY = 0;
// 找到最近的空地
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
if(grid[i][j] == 0) {
int dist = Math.abs(x - i) + Math.abs(y - j);
if(dist < minDist) {
minDist = dist;
bestX = i;
bestY = j;
}
}
}
}
// 计算总距离
int sumDist = 0;
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
if(grid[i][j] == 1) {
sumDist += Math.abs(bestX - i) + Math.abs(bestY - j);
}
}
}
return sumDist;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[][] grid = new int[n][n];
int count = 0, sumX = 0, sumY = 0;
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
grid[i][j] = sc.nextInt();
if(grid[i][j] == 1) {
sumX += i;
sumY += j;
count++;
}
}
}
if(count == n * n) {
System.out.println(-1);
return;
}
int centerX = sumX / count;
int centerY = sumY / count;
int result = calDistance(grid, n, centerX, centerY);
result = Math.min(result, calDistance(grid, n, centerX + 1, centerY + 1));
System.out.println(result);
}
}
def cal_distance(grid, n, x, y):
min_dist = n * 2
best_x = best_y = 0
# 找到最近的空地
for i in range(n):
for j in range(n):
if grid[i][j] == 0:
dist = abs(x - i) + abs(y - j)
if dist < min_dist:
min_dist = dist
best_x, best_y = i, j
# 计算总距离
sum_dist = 0
for i in range(n):
for j in range(n):
if grid[i][j] == 1:
sum_dist += abs(best_x - i) + abs(best_y - j)
return sum_dist
n = int(input())
grid = []
count = sum_x = sum_y = 0
# 读入网格并计算重心
for i in range(n):
row = list(map(int, input().split()))
grid.append(row)
for j in range(n):
if row[j] == 1:
sum_x += i
sum_y += j
count += 1
if count == n * n:
print(-1)
else:
center_x = sum_x // count
center_y = sum_y // count
result = cal_distance(grid, n, center_x, center_y)
result = min(result, cal_distance(grid, n, center_x + 1, center_y + 1))
print(result)
算法及复杂度
- 算法:贪心 + 曼哈顿距离计算
- 时间复杂度: O ( n 2 ) \mathcal{O}(n^2) O(n2) - 需要遍历网格两次
- 空间复杂度: O ( n 2 ) \mathcal{O}(n^2) O(n2) - 需要存储网格
630

被折叠的 条评论
为什么被折叠?



