1 条题解
-
0
/* A*算法 做法: 引入一个估值函数,用来估计某个点到达终点的距离。 记f是估值函数,g是真实值,那么f(state) <= g(state),越接近越好(当估值是0时,类似于Dijkstra算法) 记dist是从起点到state状态的步数; 利用的是优先队列,排序依据是dist[state] + f(state) 证明: 反证法: 假设终点第一次出堆时不是最小值,那么意味着dist[end] > dist优 那么说明堆中存在一个最优路径中的某个点(起码起点在路径上),记该点为u, dist优 = dist[u] + g(u) >= dist[u] + f(u) -> dist[end] > dist优 >= dist[u] + f(u),说明优先队列中存在一个比出堆元素更小的值,这就矛盾了。 所以说终点第一次出堆时就是最优的。 应用的环境: 1、有解(无解时,仍然会把所有空间搜索,会比一般的bfs慢,因为优先队列的操作是logn的) 2、边权非负,如果是负数,那么终点的估值有可能是负无穷,终点可能会直接出堆 性质: 除了终点以外的其他点无法在出堆或者如堆的时候确定距离,只能保证终点出堆时是最优的可以。 */ #include <iostream> #include <cstring> #include <queue> #include <unordered_map> #include <algorithm> using namespace std; typedef pair<int , string> PIS; unordered_map<string , int> dist; unordered_map<string , pair<string , char>> pre; priority_queue<PIS , vector<PIS> , greater<PIS>> heap; string ed = "12345678x"; int dx[4] = {-1 , 0 , 1 , 0} , dy[4] = {0 , 1 , 0 , -1}; char op[] = "urdl"; int f(string state)//求估值函数,这里是曼哈顿距离 { int res = 0; for(int i = 0 ; i < 9 ; i++) { if(state[i] != 'x') { int t = state[i] - '1'; res += abs(t / 3 - i / 3) + abs(t % 3 - i % 3); } } return res; } string bfs(string start) { heap.push({f(start) , start}); dist[start] = 0; while(heap.size()) { auto t = heap.top(); heap.pop(); string state = t.second; int step = dist[state];//记录到达state的实际距离 if(state == ed) break;//如果到达终点就break int k = state.find('x'); int x = k / 3 , y = k % 3; string source = state;//因为在下面state会变,所以留一个备份 for (int i = 0; i < 4; i ++ ) { int a = x + dx[i], b = y + dy[i]; if (a >= 0 && a < 3 && b >= 0 && b < 3) { swap(state[x * 3 + y], state[a * 3 + b]); if (!dist.count(state) || dist[state] > step + 1) { dist[state] = step + 1; pre[state] = {source, op[i]}; heap.push({dist[state] + f(state), state}); } swap(state[x * 3 + y], state[a * 3 + b]);//因为要多次交换,所以要恢复现场 } } } string res; while(ed != start) { res += pre[ed].second; ed = pre[ed].first; } reverse(res.begin() , res.end()); return res; } int main() { string start , seq; for(int i = 0 ; i < 9 ; i++) { char c; cin >> c; start += c; if(c != 'x') seq += c; } int cnt = 0; for(int i = 0 ; i < 8 ; i ++) for(int j = i + 1 ; j < 8 ; j++) if(seq[i] > seq[j]) cnt++; if(cnt % 2) puts("unsolvable"); else cout << bfs(start) << endl; return 0; }
- 1
信息
- ID
- 123
- 时间
- 1000ms
- 内存
- 256MiB
- 难度
- 10
- 标签
- 递交数
- 7
- 已通过
- 4
- 上传者